diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000000..6c2eea8d02 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,73 @@ +# This workflow will upload a Python Package to Release asset +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions +# Copied from vLLM github actions https://github.com/vllm-project/vllm/blob/main/.github/workflows/publish.yml +name: flexkv ci + +on: + pull_request: + branches: [ "main", "dev"] + push: + branches: [ "main", "dev"] + +# Needed to create wheel and upload assets +permissions: + contents: write + +jobs: + build: + name: Build Wheel + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: ['ubuntu-22.04'] + python-version: ['3.10'] + pytorch-version: ['2.6.0'] + cuda-version: ['12.4'] + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Linux Env + if: ${{ runner.os == 'Linux' }} + run: | + bash -x .github/workflows/scripts/env.sh + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Install CUDA ${{ matrix.cuda-version }} + run: | + bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} + + - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} + run: | + bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} + + - name: Build wheel + shell: bash + env: + TORCH_CUDA_ARCH_LIST: "8.9 9.0+PTX" + MAX_JOBS: 4 + run: | + ./build.sh --release + + - name: Get Date and Time + run: | + echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + echo "time=$(date +'%H-%M-%S')" >> $GITHUB_ENV + + - name: Upload to cos + uses: shallwefootball/s3-upload-action@master + with: + aws_key_id: ${{ secrets.COS_SECRET_ID }} + aws_secret_access_key: ${{ secrets.COS_SECRET_KEY }} + aws_bucket: ${{ secrets.COS_BUCKET }} + endpoint: ${{ secrets.COS_ENDPOINT }} + source_dir: dist + destination_dir: flexkv/${{ env.date }}/${{ env.time }} diff --git a/.github/workflows/scripts/cuda-install.sh b/.github/workflows/scripts/cuda-install.sh new file mode 100755 index 0000000000..3e4d7c8b7d --- /dev/null +++ b/.github/workflows/scripts/cuda-install.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copied from vLLM github actions https://github.com/vllm-project/vllm/blob/main/.github/workflows/scripts/cuda-install.sh + +# Replace '.' with '-' ex: 11.8 -> 11-8 +cuda_version=$(echo "$1" | tr "." "-") +# Removes '-' and '.' 
ex: ubuntu-20.04 -> ubuntu2004 +OS=$(echo "$2" | tr -d ".\-") + +# Installs CUDA +wget -nv "https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb" +sudo dpkg -i cuda-keyring_1.1-1_all.deb +rm cuda-keyring_1.1-1_all.deb +sudo apt -qq update +sudo apt -y install "cuda-${cuda_version}" "cuda-nvcc-${cuda_version}" "cuda-libraries-dev-${cuda_version}" +sudo apt clean + +# Test nvcc +PATH=/usr/local/cuda-$1/bin:${PATH} +nvcc --version + +# Log gcc, g++, c++ versions +gcc --version +g++ --version +c++ --version diff --git a/.github/workflows/scripts/env.sh b/.github/workflows/scripts/env.sh new file mode 100755 index 0000000000..299f281236 --- /dev/null +++ b/.github/workflows/scripts/env.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copied from vLLM github actions https://github.com/vllm-project/vllm/blob/main/.github/workflows/scripts/env.sh + +# This file installs common linux environment tools + +export LANG=C.UTF-8 + +sudo apt-get update && \ +sudo apt-get install -y --no-install-recommends \ + software-properties-common + +sudo apt-get install -y --no-install-recommends \ + build-essential \ + liburing-dev \ + git \ + cmake + +# Remove github bloat files to free up disk space +sudo rm -rf "/usr/local/share/boost" +sudo rm -rf "$AGENT_TOOLSDIRECTORY" +sudo rm -rf "/usr/share/dotnet" diff --git a/.github/workflows/scripts/pytorch-install.sh b/.github/workflows/scripts/pytorch-install.sh new file mode 100755 index 0000000000..559043d412 --- /dev/null +++ b/.github/workflows/scripts/pytorch-install.sh @@ -0,0 +1,16 @@ +#!/bin/bash +# Copied from vLLM github actions https://github.com/vllm-project/vllm/blob/main/.github/workflows/scripts/pytorch-install.sh + +python_executable=python$1 +pytorch_version=$2 +cuda_version=$3 + +# Install torch +$python_executable -m pip install numpy ninja cython wheel typing typing-extensions dataclasses setuptools && conda clean -ya +$python_executable -m pip install torch=="${pytorch_version}+cu${cuda_version//./}" --extra-index-url "https://download.pytorch.org/whl/cu${cuda_version//./}" + +# Print version information +$python_executable --version +$python_executable -c "import torch; print('PyTorch:', torch.__version__)" +$python_executable -c "import torch; print('CUDA:', torch.version.cuda)" +$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" diff --git a/.gitignore b/.gitignore index fd0f58caa6..03727c4ed6 100644 --- a/.gitignore +++ b/.gitignore @@ -70,3 +70,6 @@ cover/ # mypy .mypy_cache/ + +# VSCode +.vscode/ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000..0fe668086a --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,30 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [1.0.0] - 2025-09-11 + +### Added +- C++ radix tree for fast match, need set "index_accel": true in cache_config +- sync kernel launch +- a huge change that move cache engine to a library for accelerator(vLLM e.g.) to use instead of server-client mode. + This accelerate the get and put when no KVCache is matched. This version includes breaking API changes and is not backward compatible. 
+- add evict_ratio; set "evict_ratio": 0.05 in cache_config to enable it
+- reduce the bubble inside kernel launches
+- add vLLM 0.10.1.1 adapter
+
+### Fixed
+- cython release package
+
+
+## [0.1.0] - 2025-08-29
+
+### Init
+- init version
+- add license
+
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000000..301a6fe36a
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,13 @@
+# Contributing to FlexKV
+
+Thank you for your interest in contributing to FlexKV!
+
+## PR Title and Classification
+Use a prefixed PR title to indicate the type of changes. Please use one of the following:
+
+- `[bugfix]` for bugfixes
+- `[feature]` for new features
+- `[test]` for test cases
+- `[ci/build]` for build or continuous integration improvements
+- `[doc]` for documentation fixes
+- `[misc]` for PRs that do not fit the above categories. Please use this sparingly.
\ No newline at end of file
diff --git a/README.md b/README.md
index 56875811e3..ed78dbca43 100644
--- a/README.md
+++ b/README.md
@@ -14,23 +14,9 @@ FlexKV is released under the **Apache-2.0 License**. See the [LICENSE](LICENSE)
 ./build.sh
 ```
-### Use FlexKV with vLLM (v0.8.4)
+### Use FlexKV with vLLM
-Apply the patch `examples/vllm_adaption/flexkv_vllm_0_8_4.patch` to vLLM 0.8.4, then start FlexKV, vLLM, and the benchmark script:
-
-```bash
-# Start FlexKV as server
-bash benchmarks/flexkv_benchmark/run_flexkv_server.sh
-
-# Start vLLM as client
-bash benchmarks/flexkv_benchmark/serving_vllm.sh
-
-# Start benchmark
-bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh
-```
-Apply the patch `examples/vllm_adaption/flexkv_vllm_0_10_0.patch` to vLLM 0.10.0, and use the same testing method as above.
-
-> **Note**: The current script is only compatible with the `main` branch. Support for the latest features in the `dev` branch is under development.
+See [docs/vllm_adapter/README_en.md](docs/vllm_adapter/README_en.md)
 ## Design Architecture
@@ -88,6 +74,7 @@ FlexKV performs:
 - The main branch is the stable branch, which maintains already tested commits. Please pull from main branch if you need stable code.
 - The dev branch is the development branch, which contains newer features. Please branch from and merge into dev if you need new features or are developing new functionality.
 - The bugfix branch is for bug fixes, maintaining urgent bugs that need immediate resolution or documentation that requires prompt updates. If you need to fix a bug or update documentation urgently, please branch from and merge into the bugfix branch.
+- The stable branch refers to the previous main branch state, intended only for rollback or extremely conservative use cases (e.g., production deployment). Its use is discouraged.
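A minimal sketch of the workflow implied by the branch policy above (the remote name `origin` and the example branch names are assumptions, not part of the repository's tooling):

```bash
# New features: branch from and merge back into dev
git fetch origin
git checkout -b my-feature origin/dev
# ...commit changes, then open a PR targeting dev...

# Urgent bug or doc fixes: branch from and merge back into bugfix
git checkout -b my-fix origin/bugfix

# Stable code: simply track main
git checkout main && git pull origin main
```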
## Roadmap diff --git a/README_zh.md b/README_zh.md index 8223a5d9c0..0618a83220 100644 --- a/README_zh.md +++ b/README_zh.md @@ -16,21 +16,7 @@ FlexKV 采用 **Apache-2.0 开源协议**,详细信息请参见 [LICENSE](LICE ### 以 vLLM 为例使用 FlexKV -在 vLLM 0.8.4 版本中应用patch `examples/vllm_adaption/flexkv_vllm_0_8_4.patch`,分别启动 FlexKV、vLLM 和测试脚本: - -```bash -# 启动 FlexKV 作为服务端 -bash benchmarks/flexkv_benchmark/run_flexkv_server.sh - -# 启动 vLLM 作为客户端 -bash benchmarks/flexkv_benchmark/serving_vllm.sh - -# 启动性能测试 -bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh -``` -在 vLLM 0.10.0 版本中应用patch `examples/vllm_adaption/flexkv_vllm_0_10_0.patch`,测试方法同上。 - -> **注意**:当前脚本仅适配 `main` 分支。`dev` 分支的最新特性支持脚本正在开发中。 +见[docs/vllm_adapter/README_zh.md](docs/vllm_adapter/README_zh.md) ## 设计框架 @@ -88,6 +74,7 @@ FlexKV 在处理 *get* 请求时: - main 为稳定分支,维护已经测试过的commit。需要稳定的代码请从此分支拉取。 - dev 为开发分支,维护较新特性。需要新特性和开发新特性请从此分支拉取和合入。 - bugfix 为bug分支,维护需要立即解决的bug或需要立即更新的文档。需要解决bug和立即更新的文档请从此分支拉取和合入。 +- stable 为上一个版本的main分支位置,仅用于回滚以及极其保守的情况使用(如产品化)。不鼓励使用此版本。 ## Roadmap diff --git a/VERSION b/VERSION new file mode 100644 index 0000000000..3eefcb9dd5 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +1.0.0 diff --git a/benchmarks/benchmark_kvmanager.py b/benchmarks/benchmark_kvmanager.py deleted file mode 100644 index c1cedb9519..0000000000 --- a/benchmarks/benchmark_kvmanager.py +++ /dev/null @@ -1,273 +0,0 @@ -import os -import tempfile -from multiprocessing import Process -import argparse -import json -import time -from dataclasses import dataclass - -import torch - -from flexkv.server.client import KVDPClient, KVTPClient -from flexkv.server.server import KVServer, SchedulerServer -from flexkv.common.config import ModelConfig, CacheConfig -from flexkv.common.storage import KVCacheLayoutType, KVCacheLayout -from flexkv.common.debug import flexkv_logger -from utils import load_config - -flexkv_logger.set_level("INFO") - - -@dataclass -class BenchmarkConfig: - num_layers_to_transfer: int - batch_size: int - sequence_length: int - cache_ratio: float - -def run_server(model_config, cache_config, server_recv_port): - """Run server process""" - kvserver = KVServer(model_config, cache_config, server_recv_port) - kvserver.run() - time.sleep(10) - -def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_config): - """Run tp_client process""" - device_id = tp_rank + dp_client_id * model_config.tp_size - tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank) - - num_gpu_blocks = cache_config.num_gpu_blocks - - gpu_kv_layout = KVCacheLayout( - type=cache_config.gpu_kv_layout_type, - num_layer=model_config.num_layers, - num_block=num_gpu_blocks, - tokens_per_block=cache_config.tokens_per_block, - num_head=model_config.num_kv_heads, - head_size=model_config.head_size, - is_mla=model_config.use_mla, - ) - - # Create GPU blocks for this tp_rank in the tp_client process - gpu_blocks_for_tp = [] - for _ in range(model_config.num_layers): - gpu_blocks_for_tp.append( - torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) - ) - tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) - # Keep the process running - while True: - time.sleep(1) - -def shutdown_tp_client(tp_client_processes): - for tp_process in tp_client_processes: - if tp_process.is_alive(): - tp_process.terminate() - tp_process.join(timeout=5) - if tp_process.is_alive(): - print(f"Force killing tp_client process {tp_process.pid}") - tp_process.kill() - tp_process.join(timeout=2) - -class FlexkvWrapper: - def __init__(self, 
model_config, cache_config, server_recv_port): - self.model_config = model_config - self.cache_config = cache_config - self.server_recv_port = server_recv_port - - self.use_scheduler_server = model_config.dp_size == 1 - if self.use_scheduler_server: - self.launch_scheduler_server() - else: - self.launch_server() - - def launch_server(self): - def server_process(): - kvserver = KVServer(self.model_config, self.cache_config, self.server_recv_port) - kvserver.run() - time.sleep(10) - self.server_process = Process( - target=server_process, - daemon=False - ) - self.server_process.start() - time.sleep(5) - self.dp_client = KVDPClient(self.server_recv_port, self.model_config) - - def launch_scheduler_server(self): - self.scheduler_server = SchedulerServer(self.model_config, self.cache_config, self.server_recv_port) - self.scheduler_server.start_server_thread() - time.sleep(10) - - @property - def dp_client_id(self): - if self.use_scheduler_server: - return 0 - else: - return self.dp_client.dp_client_id - - def put_async(self, token_ids, slot_mapping, token_mask=None): - if self.use_scheduler_server: - return self.scheduler_server.put_async(token_ids, slot_mapping, token_mask) - else: - return self.dp_client.put_async(token_ids, slot_mapping, token_mask) - - def get_async(self, token_ids, slot_mapping, token_mask=None): - if self.use_scheduler_server: - return self.scheduler_server.get_async(token_ids, slot_mapping, token_mask) - else: - return self.dp_client.get_async(token_ids, slot_mapping, token_mask) - - def wait(self, request_ids): - if self.use_scheduler_server: - return self.scheduler_server.wait(request_ids) - else: - return self.dp_client.wait(request_ids) - - def try_wait(self, request_ids): - if self.use_scheduler_server: - return self.scheduler_server.try_wait(request_ids) - else: - return self.dp_client.try_wait(request_ids) - - def check_running(self): - if self.use_scheduler_server: - return self.scheduler_server.check_running() - else: - return self.dp_client.check_running() - - def shutdown(self): - if not self.use_scheduler_server: - try: - # Send a shutdown request to the server - self.dp_client.shutdown() - # Wait a bit for graceful shutdown - time.sleep(3) - except Exception as e: - print(f"Error sending shutdown request: {e}") - if self.server_process.is_alive(): - self.server_process.terminate() - self.server_process.join(timeout=10) - if self.server_process.is_alive(): - print(f"Force killing server process {self.server_process.pid}") - self.server_process.kill() - self.server_process.join(timeout=5) - if self.server_recv_port.startswith('ipc://'): - temp_file = self.server_recv_port[6:] # Remove 'ipc://' prefix - try: - if os.path.exists(temp_file): - os.unlink(temp_file) - except Exception as e: - print(f"Error cleaning up temporary file: {e}") - else: - self.scheduler_server.shutdown() - -def benchmark_kvmanager(model_config, cache_config, benchmark_config, server_recv_port): - if model_config.tp_size * model_config.dp_size > torch.cuda.device_count(): - raise ValueError(f"tp_size {model_config.tp_size} * dp_size {model_config.dp_size} is greater than " - f"the number of available GPUs {torch.cuda.device_count()}") - print(f"{model_config = }") - print(f"{cache_config = }") - print(f"{benchmark_config = }") - flexkv_wrapper = FlexkvWrapper(model_config, cache_config, server_recv_port) - - tp_client_processes = [] - - sequence_length = benchmark_config.sequence_length - batch_size = benchmark_config.batch_size - num_required_gpu_blocks = sequence_length * batch_size // 
cache_config.tokens_per_block - cache_config.num_gpu_blocks = num_required_gpu_blocks - print(f"allocate {num_required_gpu_blocks} gpu blocks for benchmark") - for tp_rank in range(model_config.tp_size): - tp_client_process = Process( - target=run_tp_client, - args=(flexkv_wrapper.dp_client_id, tp_rank, server_recv_port, - model_config, cache_config), - daemon=True - ) - tp_client_process.start() - tp_client_processes.append(tp_client_process) - time.sleep(5) - - batch_sequence_tensor = [] - batch_slot_mapping = [] - cache_length = int(sequence_length * benchmark_config.cache_ratio) - - # generate requests - for i in range(batch_size): - batch_sequence_tensor.append(torch.randint(0, 100000, (sequence_length, ), dtype=torch.int64)) - batch_slot_mapping.append(torch.arange(i * sequence_length, (i+1) * sequence_length, dtype=torch.int64)) - - while not flexkv_wrapper.check_running(): - time.sleep(0.1) - print("waiting for flexkv wrapper to be ready") - # benchmark put - start_time = time.time() - put_ids = [] - if benchmark_config.cache_ratio > 0: - for i in range(batch_size): - put_ids.append(flexkv_wrapper.put_async(batch_sequence_tensor[i][:cache_length], - batch_slot_mapping[i][:cache_length], - token_mask=None)) - put_result = flexkv_wrapper.wait(put_ids) - end_time = time.time() - time.sleep(1) - elapsed_time_put = end_time - start_time - put_tokens = 0 - for _, return_mask in put_result.items(): - put_tokens += return_mask.sum().item() - transfer_data_size_GB = put_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 - transfer_bandwidth_put = transfer_data_size_GB / elapsed_time_put - print(f"put {put_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " - f"time: {elapsed_time_put*1000:.2f}ms, bandwidth: {transfer_bandwidth_put:.2f} GB/s") - - #benchmark get - start_time = time.time() - get_ids = [] - for i in range(batch_size): - get_ids.append(flexkv_wrapper.get_async(batch_sequence_tensor[i], - batch_slot_mapping[i], - token_mask=None)) - get_result = flexkv_wrapper.wait(get_ids) - end_time = time.time() - elapsed_time_get = end_time - start_time - cached_tokens = 0 - all_tokens = 0 - for _, return_mask in get_result.items(): - cached_tokens += return_mask.sum().item() - all_tokens += len(return_mask) - transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 - transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get - print(f"get {cached_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " - f"cache_ratio: {cached_tokens * 100 / all_tokens:.2f}%, " - f"time: {elapsed_time_get*1000:.2f}ms, bandwidth: {transfer_bandwidth_get:.2f} GB/s") - - shutdown_tp_client(tp_client_processes) - flexkv_wrapper.shutdown() - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default="benchmarks/example_config.json") - # benchmark config - parser.add_argument("--num-layers", type=int, default=-1) - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--sequence-length", type=int, default=1024) - parser.add_argument("--cache-ratio", type=float, default=1) - return parser.parse_args() - -if __name__ == "__main__": - args = parse_args() - benchmark_config = BenchmarkConfig( - num_layers_to_transfer=args.num_layers, - batch_size=args.batch_size, - sequence_length=args.sequence_length, - cache_ratio=args.cache_ratio - ) - model_config, cache_config = load_config(args.config) - #cache_config.num_cpu_blocks = 8192 - 2048 - # pad sequence length to divisible 
by tokens_per_block - benchmark_config.sequence_length = \ - ((benchmark_config.sequence_length - 1) // cache_config.tokens_per_block + 1) * cache_config.tokens_per_block - server_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" - benchmark_kvmanager(model_config, cache_config, benchmark_config, server_recv_port) diff --git a/benchmarks/benchmark_single_batch.py b/benchmarks/benchmark_single_batch.py new file mode 100644 index 0000000000..397030becd --- /dev/null +++ b/benchmarks/benchmark_single_batch.py @@ -0,0 +1,191 @@ +import tempfile +from multiprocessing import Process +import argparse +import time +from dataclasses import dataclass + +import torch + +from flexkv.server.client import KVTPClient +from flexkv.common.storage import KVCacheLayout +from flexkv.common.debug import flexkv_logger +from flexkv.common.config import ModelConfig, CacheConfig +from utils import load_config +from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVResponseStatus + +flexkv_logger.set_level("INFO") + + +@dataclass +class BenchmarkConfig: + num_layers_to_transfer: int + batch_size: int + sequence_length: int + cache_ratio: float + clear_cpu_cache: bool + +def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_config): + """Run tp_client process""" + device_id = tp_rank + dp_client_id * model_config.tp_size + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) + + num_gpu_blocks = cache_config.num_gpu_blocks + + gpu_kv_layout = KVCacheLayout( + type=cache_config.gpu_kv_layout_type, + num_layer=model_config.num_layers, + num_block=num_gpu_blocks, + tokens_per_block=cache_config.tokens_per_block, + num_head=model_config.num_kv_heads, + head_size=model_config.head_size, + is_mla=model_config.use_mla, + ) + + # Create GPU blocks for this tp_rank in the tp_client process + gpu_blocks_for_tp = [] + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) + # Keep the process running + while True: + time.sleep(1) + +def shutdown_tp_client(tp_client_processes): + for tp_process in tp_client_processes: + if tp_process.is_alive(): + tp_process.terminate() + tp_process.join(timeout=5) + if tp_process.is_alive(): + print(f"Force killing tp_client process {tp_process.pid}") + tp_process.kill() + tp_process.join(timeout=2) + +def benchmark_flexkv(model_config: ModelConfig, + cache_config: CacheConfig, + benchmark_config: BenchmarkConfig, + gpu_register_port: str, + server_recv_port: str): + if model_config.tp_size * model_config.dp_size > torch.cuda.device_count(): + raise ValueError(f"tp_size {model_config.tp_size} * dp_size {model_config.dp_size} is greater than " + f"the number of available GPUs {torch.cuda.device_count()}") + print(f"{benchmark_config = }") + kvmanager = KVManager(model_config, cache_config, gpu_register_port, server_recv_port) + kvmanager.start() + + tp_client_processes = [] + + sequence_length = benchmark_config.sequence_length + batch_size = benchmark_config.batch_size + num_required_gpu_blocks = sequence_length * batch_size // cache_config.tokens_per_block + cache_config.num_gpu_blocks = num_required_gpu_blocks + print(f"allocate {num_required_gpu_blocks} gpu blocks for benchmark") + for tp_rank in range(model_config.tp_size): + tp_client_process = Process( + target=run_tp_client, + args=(0, tp_rank, gpu_register_port, + model_config, 
cache_config), + daemon=True + ) + tp_client_process.start() + tp_client_processes.append(tp_client_process) + + while not kvmanager.is_ready(): + time.sleep(3) + flexkv_logger.info("waiting for flexkv to be ready") + flexkv_logger.info("flexkv is ready") + + batch_sequence_tensor = [] + batch_slot_mapping = [] + cache_length = int(sequence_length * benchmark_config.cache_ratio) + + # generate requests + for i in range(batch_size): + batch_sequence_tensor.append(torch.randint(0, 100000, (sequence_length, ), dtype=torch.int64)) + batch_slot_mapping.append(torch.arange(i * sequence_length, (i+1) * sequence_length, dtype=torch.int64)) + + # benchmark put + start_time = time.time() + batch_put_ids = [] + if benchmark_config.cache_ratio > 0: + for i in range(batch_size): + task_id = kvmanager.put_async(batch_sequence_tensor[i][:cache_length], + batch_slot_mapping[i][:cache_length], + token_mask=None) + batch_put_ids.append(task_id) + put_result = kvmanager.wait(batch_put_ids, completely=True) + end_time = time.time() + + if benchmark_config.clear_cpu_cache: + kvmanager._clear_cpu_cache() + + elapsed_time_put = end_time - start_time + put_tokens = 0 + for _, response in put_result.items(): + if response.status == KVResponseStatus.SUCCESS: + put_tokens += response.return_mask.sum().item() + transfer_data_size_GB = put_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 + transfer_bandwidth_put = transfer_data_size_GB / elapsed_time_put + print(f"put {put_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"time: {elapsed_time_put*1000:.2f}ms, bandwidth: {transfer_bandwidth_put:.2f} GB/s") + + all_tokens = 0 + start_time = time.time() + batch_get_ids = [] + for i in range(batch_size): + all_tokens += len(batch_sequence_tensor[i]) + task_id, _ = kvmanager.get_match(batch_sequence_tensor[i], + token_mask=None) + batch_get_ids.append(task_id) + get_match_time = time.time() - start_time + kvmanager.launch(batch_get_ids, batch_slot_mapping) + get_result = kvmanager.wait(batch_get_ids) + elapsed_time_get = time.time() - start_time + cached_tokens = 0 + for _, response in get_result.items(): + if response.status == KVResponseStatus.SUCCESS: + cached_tokens += response.return_mask.sum().item() + transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 + transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get + print(f"get {cached_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"cache_ratio: {cached_tokens * 100 / all_tokens:.2f}%, " + f"match time: {get_match_time*1000:.2f}ms, " + f"e2e time: {elapsed_time_get*1000:.2f}ms, " + f"bandwidth: {transfer_bandwidth_get:.2f} GB/s") + + shutdown_tp_client(tp_client_processes) + kvmanager.shutdown() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="benchmarks/example_config.json") + # benchmark config + parser.add_argument("--num-layers", type=int, default=-1) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--sequence-length", type=int, default=1024) + parser.add_argument("--cache-ratio", type=float, default=1) + parser.add_argument("--clear-cpu-cache", action="store_true") + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + benchmark_config = BenchmarkConfig( + num_layers_to_transfer=args.num_layers, + batch_size=args.batch_size, + sequence_length=args.sequence_length, + cache_ratio=args.cache_ratio, + clear_cpu_cache=args.clear_cpu_cache + ) + 
model_config, cache_config = load_config(args.config) + #cache_config.num_cpu_blocks = 8192 - 2048 + # pad sequence length to divisible by tokens_per_block + benchmark_config.sequence_length = \ + ((benchmark_config.sequence_length - 1) // cache_config.tokens_per_block + 1) * cache_config.tokens_per_block + import uuid + gpu_register_port = f"ipc:///tmp/flexkv_gpu_{uuid.uuid4().hex[:8]}" + server_recv_port = f"ipc:///tmp/flexkv_srv_{uuid.uuid4().hex[:8]}" + + benchmark_flexkv(model_config, cache_config, benchmark_config, gpu_register_port, server_recv_port) diff --git a/benchmarks/benchmark_workers.py b/benchmarks/benchmark_workers.py index f9c50f8bc7..8c1c2182e3 100644 --- a/benchmarks/benchmark_workers.py +++ b/benchmarks/benchmark_workers.py @@ -9,7 +9,7 @@ import torch -from flexkv.common.transfer import TransferOp, TransferType, TransferDescriptor +from flexkv.common.transfer import TransferOp, TransferType from flexkv.transfer.worker import GPUCPUTransferWorker, CPUSSDDiskTransferWorker, WorkerHandle, tpGPUCPUTransferWorker from flexkv.storage.allocator import CPUAllocator, GPUAllocator, SSDAllocator from flexkv.common.storage import KVCacheLayoutType, KVCacheLayout @@ -17,7 +17,7 @@ from flexkv.common.debug import flexkv_logger -flexkv_logger.set_level("OFF") +# flexkv_logger.set_level("OFF") @dataclass class BenchmarkConfig: @@ -50,11 +50,10 @@ def make_configs(args: dict) -> Tuple[ModelConfig, CacheConfig, BenchmarkConfig] def create_cpu_gpu_worker( model_config: ModelConfig, - cache_config: CacheConfig, - bench_config: BenchmarkConfig) -> Tuple[WorkerHandle, mp.Queue]: + cache_config: CacheConfig) -> Tuple[WorkerHandle, mp.Queue]: mp.set_start_method('spawn', force=True) cpu_layout = KVCacheLayout( - type=KVCacheLayoutType.LAYERWISE, + type=KVCacheLayoutType(cache_config.cpu_kv_layout_type), num_layer=model_config.num_layers, num_block=cache_config.num_cpu_blocks, tokens_per_block=cache_config.tokens_per_block, @@ -85,7 +84,6 @@ def create_cpu_gpu_worker( finished_ops_queue = mp.Queue() if model_config.tp_size == 1: worker_handle = GPUCPUTransferWorker.create_worker( - worker_id=0, finished_ops_queue=finished_ops_queue, gpu_blocks=gpu_handles[0].get_tensor_handle_list(), cpu_blocks=cpu_handle.get_tensor(), @@ -100,7 +98,6 @@ def create_cpu_gpu_worker( ) else: worker_handle = tpGPUCPUTransferWorker.create_worker( - worker_id=0, finished_ops_queue=finished_ops_queue, gpu_blocks=[handle.get_tensor_handle_list() for handle in gpu_handles], cpu_blocks=cpu_handle.get_tensor(), @@ -121,11 +118,10 @@ def create_cpu_gpu_worker( def create_cpu_ssd_worker( model_config: ModelConfig, - cache_config: CacheConfig, - bench_config: BenchmarkConfig) -> Tuple[WorkerHandle, mp.Queue]: + cache_config: CacheConfig) -> Tuple[WorkerHandle, mp.Queue]: mp.set_start_method('spawn', force=True) cpu_layout = KVCacheLayout( - type=KVCacheLayoutType.LAYERWISE, + type=KVCacheLayoutType(cache_config.cpu_kv_layout_type), num_layer=model_config.num_layers, num_block=cache_config.num_cpu_blocks, tokens_per_block=cache_config.tokens_per_block, @@ -133,7 +129,7 @@ def create_cpu_ssd_worker( head_size=model_config.head_size, ) ssd_layout = KVCacheLayout( - type=KVCacheLayoutType.LAYERWISE, + type=KVCacheLayoutType(cache_config.ssd_kv_layout_type), num_layer=model_config.num_layers, num_block=cache_config.num_ssd_blocks, tokens_per_block=cache_config.tokens_per_block, @@ -153,7 +149,6 @@ def create_cpu_ssd_worker( ) finished_ops_queue = mp.Queue() worker_handle = CPUSSDDiskTransferWorker.create_worker( - worker_id=10, 
finished_ops_queue=finished_ops_queue, cpu_blocks=cpu_handle.get_tensor(), ssd_files=ssd_handle.get_file_list(), @@ -195,29 +190,25 @@ def bench_worker(args): shuffle_ids = bench_config.shuffle_ids if transfer_type == TransferType.H2D or transfer_type == TransferType.D2H: - worker_handle, finished_ops_queue = create_cpu_gpu_worker(model_config, cache_config, bench_config) + worker_handle, finished_ops_queue = create_cpu_gpu_worker(model_config, cache_config) elif transfer_type == TransferType.H2DISK or transfer_type == TransferType.DISK2H: - worker_handle, finished_ops_queue = create_cpu_ssd_worker(model_config, cache_config, bench_config) + worker_handle, finished_ops_queue = create_cpu_ssd_worker(model_config, cache_config) else: raise ValueError(f"Unsupported transfer type: {transfer_type} for benchmark, " f"currently only support {TransferType.H2D.name}, {TransferType.D2H.name}, " f"{TransferType.H2DISK.name}, {TransferType.DISK2H.name}") if shuffle_ids: - block_ids = torch.randperm(num_blocks_to_transfer) + block_ids = torch.randperm(num_blocks_to_transfer).numpy() else: - block_ids = torch.arange(num_blocks_to_transfer) + block_ids = torch.arange(num_blocks_to_transfer).numpy() transfer_op = TransferOp( transfer_type=transfer_type, layer_id=0, layer_granularity=num_layers_to_transfer, - src_descriptor=TransferDescriptor( - physical_block_ids=block_ids, - ), - dst_descriptor=TransferDescriptor( - physical_block_ids=block_ids, - ), + src_block_ids=block_ids, + dst_block_ids=block_ids, graph_id=0, dp_id=0, successors=[], @@ -226,8 +217,8 @@ def bench_worker(args): if transfer_type == TransferType.DISK2H: tmp_op = copy.deepcopy(transfer_op) tmp_op.transfer_type = TransferType.H2DISK - tmp_op.src_descriptor = transfer_op.dst_descriptor - tmp_op.dst_descriptor = transfer_op.src_descriptor + tmp_op.src_block_ids = transfer_op.dst_block_ids + tmp_op.dst_block_ids = transfer_op.src_block_ids launch_transfer(worker_handle, finished_ops_queue, tmp_op) for _ in range(warmup_round): launch_transfer(worker_handle, finished_ops_queue, transfer_op) diff --git a/benchmarks/example_config.json b/benchmarks/example_config.json index 630bdded0b..d4854557c3 100644 --- a/benchmarks/example_config.json +++ b/benchmarks/example_config.json @@ -14,11 +14,10 @@ "enable_remote": false, "tokens_per_block": 16, "use_gds": false, - "use_pinned_memory": true, "gpu_kv_layout_type": "LAYERWISE", - "cpu_kv_layout_type": "LAYERWISE", - "ssd_kv_layout_type": "LAYERWISE", - "remote_kv_layout_type": "LAYERWISE", + "cpu_kv_layout_type": "BLOCKWISE", + "ssd_kv_layout_type": "BLOCKWISE", + "remote_kv_layout_type": "BLOCKWISE", "num_cpu_blocks": 2048, "num_ssd_blocks": 4096, "num_remote_blocks": null, @@ -28,7 +27,7 @@ "transfer_sms_d2h": 8, "max_blocks_per_file": 32000, "ssd_cache_dir": "./ssd_cache1/", - "ssd_cache_iouring_entries": 512, + "ssd_cache_iouring_entries": 32, "ssd_cache_iouring_flags": 0, "remote_cache_size_mode": "file_size", "remote_file_size": null, @@ -40,6 +39,8 @@ "trace_file_path": "./flexkv_trace.log", "trace_max_file_size_mb": 100, "trace_max_files": 5, - "trace_flush_interval_ms": 1000 + "trace_flush_interval_ms": 1000, + "evict_ratio": 0.05, + "index_accel": true } } diff --git a/build.sh b/build.sh index 8b34f27df4..b976fec7d9 100755 --- a/build.sh +++ b/build.sh @@ -40,13 +40,6 @@ echo "=== Setting BUILD_LIB_PATH to $BUILD_LIB_PATH ===" cd .. -echo "=== Installing package with pip ===" -if [ "$BUILD_TYPE" = "debug" ]; then - FLEXKV_DEBUG=1 pip install --no-build-isolation -e . 
-else - FLEXKV_DEBUG=0 pip install --no-build-isolation -e . -fi - # Set LD_LIBRARY_PATH for immediate use export LD_LIBRARY_PATH=$BUILD_LIB_PATH:$LD_LIBRARY_PATH echo "Added $BUILD_LIB_PATH to LD_LIBRARY_PATH for current session" @@ -69,3 +62,11 @@ fi echo "=== Build and installation completed successfully in ${BUILD_TYPE} mode ===" echo "You can now run tests directly without setting LD_LIBRARY_PATH manually" + +if [ "$BUILD_TYPE" = "debug" ]; then + FLEXKV_DEBUG=1 pip install -v --no-build-isolation -e . +elif [ "$BUILD_TYPE" = "release" ]; then + FLEXKV_DEBUG=0 python setup.py bdist_wheel -v +else + FLEXKV_DEBUG=0 pip install -v --no-build-isolation -e . +fi diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index a991954be0..c03ad74c40 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -19,6 +19,7 @@ #include "tp_transfer_thread_group.h" #include "transfer.cuh" #include "transfer_ssd.h" +#include "radix_tree.h" namespace py = pybind11; @@ -143,12 +144,10 @@ PYBIND11_MODULE(c_ext, m) { py::class_(m, "TPTransferThreadGroup") .def(py::init> &, - torch::Tensor &, int>()) + torch::Tensor &, int, torch::Tensor &, torch::Tensor &, torch::Tensor &>()) .def("tp_group_transfer", &flexkv::TPTransferThreadGroup::tp_group_transfer, - py::arg("gpu_block_id_tensor"), py::arg("gpu_kv_stride_in_bytes"), - py::arg("gpu_block_stride_in_bytes"), - py::arg("gpu_chunk_size_in_bytes"), py::arg("cpu_block_id_tensor"), + py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"), py::arg("cpu_kv_stride_in_bytes"), py::arg("cpu_layer_stride_in_bytes"), py::arg("cpu_block_stride_in_bytes"), @@ -195,4 +194,36 @@ PYBIND11_MODULE(c_ext, m) { "Call Pcfs::write from C++", py::arg("file_nodeid"), py::arg("offset"), py::arg("buffer"), py::arg("size"), py::arg("thread_id")); #endif + + py::class_(m, "CRadixTreeIndex") + .def(py::init()) + .def("is_empty", &flexkv::CRadixTreeIndex::is_empty) + .def("reset", &flexkv::CRadixTreeIndex::reset) + .def("lock", &flexkv::CRadixTreeIndex::lock, py::arg("node")) + .def("unlock", &flexkv::CRadixTreeIndex::unlock, py::arg("node")) + .def("set_ready", &flexkv::CRadixTreeIndex::set_ready, + py::arg("node"), py::arg("ready"), py::arg("ready_length")) + .def("insert", &flexkv::CRadixTreeIndex::insert, py::return_value_policy::reference, + py::arg("physical_block_ids"), py::arg("block_hashes"), py::arg("num_blocks"), + py::arg("num_insert_blocks"), py::arg("ready") = true, py::arg("node") = nullptr, + py::arg("num_matched_blocks") = -1, py::arg("last_node_matched_length") = -1) + .def("evict", &flexkv::CRadixTreeIndex::evict, py::arg("evicted_blocks"), py::arg("num_evicted")) + .def("total_cached_blocks", &flexkv::CRadixTreeIndex::total_cached_blocks) + .def("total_unready_blocks", &flexkv::CRadixTreeIndex::total_unready_blocks) + .def("total_ready_blocks", &flexkv::CRadixTreeIndex::total_ready_blocks) + .def("match_prefix", &flexkv::CRadixTreeIndex::match_prefix, + py::arg("block_hashes"), py::arg("num_blocks"), py::arg("update_cache_info")); + + py::class_(m, "CRadixNode") + .def(py::init()) + .def("size", &flexkv::CRadixNode::size); + + py::class_>(m, "CMatchResult") + .def(py::init *>()) + .def_readonly("last_ready_node", &flexkv::CMatchResult::last_ready_node) + .def_readonly("last_node", &flexkv::CMatchResult::last_node) + .def_readonly("physical_blocks", &flexkv::CMatchResult::physical_blocks) + .def_readonly("num_ready_matched_blocks", &flexkv::CMatchResult::num_ready_matched_blocks) + .def_readonly("num_matched_blocks", &flexkv::CMatchResult::num_matched_blocks) + 
.def_readonly("last_node_matched_length", &flexkv::CMatchResult::last_node_matched_length); } diff --git a/csrc/radix_tree.cpp b/csrc/radix_tree.cpp new file mode 100644 index 0000000000..32dc8ab0e7 --- /dev/null +++ b/csrc/radix_tree.cpp @@ -0,0 +1,268 @@ +#include +#include +#include +#include +#include +#include + +#include "cache_utils.h" +#include "radix_tree.h" + +namespace flexkv { + +CRadixNode::CRadixNode(CRadixTreeIndex *index, bool ready, int lock_cnt) { + assert(index != nullptr); + + this->on_leaf = false; + this->parent = nullptr; + this->index = index; + this->ready = ready; + this->lock_cnt = lock_cnt; + + struct timeval now; + gettimeofday(&now, nullptr); + last_access_time = now.tv_sec * 1000 + now.tv_usec / 10000; + + index->inc_node_count(); +} + +CRadixNode::~CRadixNode() { + assert(parent == nullptr); + + block_hashes.clear(); + physical_blocks.clear(); + children.clear(); + + index->dec_node_count(); +} + +CRadixNode *CRadixNode::split(int prefix_length) { + assert(prefix_length < size()); + assert(prefix_length > 0); + assert(parent != nullptr); + + auto new_node = new CRadixNode(index, is_ready(), 0); + new_node->set_time(get_time()); + new_node->set_parent(parent); + get_index()->add_node(new_node); + + auto &new_block_hashes = new_node->get_block_hashes(); + auto &new_physical_blocks = new_node->get_physical_blocks(); + + new_block_hashes.insert(new_block_hashes.end(), block_hashes.cbegin(), block_hashes.cbegin() + prefix_length); + new_physical_blocks.insert(new_physical_blocks.end(), physical_blocks.cbegin(), physical_blocks.cbegin() + prefix_length); + + block_hashes.erase(block_hashes.begin(), block_hashes.begin() + prefix_length); + physical_blocks.erase(physical_blocks.begin(), physical_blocks.begin() + prefix_length); + + parent->set_child(new_node->get_head_hash(), new_node); + new_node->set_parent(parent); + new_node->set_child(get_head_hash(), this); + + set_parent(new_node); + return new_node; +} + +void CRadixNode::merge_child() { + auto child = children.begin()->second; + + assert(get_num_children() == 1); + assert(child->is_leaf()); + + block_hashes.insert(block_hashes.end(), child->get_block_hashes().cbegin(), + child->get_block_hashes().cend()); + physical_blocks.insert(physical_blocks.end(), child->get_physical_blocks().cbegin(), + child->get_physical_blocks().cend()); + + set_time(std::max(get_time(), child->get_time())); + children.clear(); + + child->clear_parent(); + index->remove_leaf(child); + index->remove_node(child); +} + +std::deque *CRadixNode::shrink(int length) { + assert(length < size()); + assert(length > 0); + assert(is_leaf()); + assert(in_use() == false); + + auto remaining_length = size() - length; + auto shrink_blocks = new std::deque(); + + shrink_blocks->insert(shrink_blocks->end(), physical_blocks.begin() + remaining_length, physical_blocks.end()); + + block_hashes.erase(block_hashes.begin() + remaining_length, block_hashes.end()); + physical_blocks.erase(physical_blocks.begin() + remaining_length, physical_blocks.end()); + + return shrink_blocks; +} + +CRadixNode *CRadixTreeIndex::insert(torch::Tensor &physical_block_ids, + torch::Tensor &block_hashes, int num_blocks, int num_insert_blocks, bool ready, + CRadixNode *last_node, int num_matched_blocks, int last_node_matched_length) { + if (num_insert_blocks == -1) { + num_insert_blocks = num_blocks; + } + assert(num_insert_blocks >= 0); + assert(num_insert_blocks <= num_blocks); + assert(physical_block_ids.ndim() == 1); + + if (last_node == nullptr) { + auto match_result = 
match_prefix(block_hashes, num_blocks, true); + num_matched_blocks = match_result->num_matched_blocks; + last_node_matched_length = match_result->last_node_matched_length; + last_node = match_result->last_node; + } + + assert(last_node != nullptr); + assert(last_node_matched_length != 0 || is_root(last_node)); + assert(physical_block_ids.size() == num_insert_blocks - num_matched_blocks); + + if (num_matched_blocks >= num_insert_blocks) { + return nullptr; + } + + auto new_node = new CRadixNode(this, ready, 0); + auto &new_block_hashes = new_node->get_block_hashes(); + auto &new_physical_blocks = new_node->get_physical_blocks(); + + auto block_hashes_ptr = block_hashes.data_ptr(); + auto physical_block_ids_ptr = physical_block_ids.data_ptr(); + for (auto i = 0; i + num_matched_blocks < num_insert_blocks; i++) { + new_block_hashes.insert(new_block_hashes.end(), block_hashes_ptr[i+num_matched_blocks]); + new_physical_blocks.insert(new_physical_blocks.end(), physical_block_ids_ptr[i]); + } + + if (last_node_matched_length < last_node->size()) { + last_node->split(last_node_matched_length); + last_node = last_node->get_parent(); + assert(last_node != nullptr); + } + + if (last_node->is_leaf()) { + remove_leaf(last_node); + } + + new_node->set_parent(last_node); + last_node->set_child(new_node->get_head_hash(), new_node); + + add_node(new_node); + add_leaf(new_node); + return new_node; +} + +int CRadixTreeIndex::evict(torch::Tensor &evicted_blocks, int num_evicted) { + int64_t *evicted_blocks_ptr = evicted_blocks.data_ptr(); + int has_evicted = 0; + std::priority_queue, CRadixNode::Compare> candidate; + + for (auto it = leaf_list.begin(); it != leaf_list.end(); it++) { + if ((*it)->evictable()) { + candidate.push(*it); + } + } + + while ((has_evicted < num_evicted) && candidate.size()) { + auto node = candidate.top(); + candidate.pop(); + + if (node->size() > num_evicted - has_evicted) { + auto blocks = node->shrink(num_evicted - has_evicted); + for (auto it = blocks->begin(); it != blocks->end(); it++) { + evicted_blocks_ptr[has_evicted] = *it; + has_evicted++; + } + delete blocks; + } else { + auto parent = node->get_parent(); + auto &blocks = node->get_physical_blocks(); + + assert(parent != nullptr); + parent->remove_child(node->get_head_hash()); + + for (auto it = blocks.begin(); it != blocks.end(); it++) { + evicted_blocks_ptr[has_evicted] = *it; + has_evicted++; + } + + if (parent->is_leaf() && !is_root(parent)) { + add_leaf(parent); + if (parent->evictable()) { + candidate.push(parent); + } + } + + node->clear_parent(); + remove_leaf(node); + remove_node(node); + } + } + return has_evicted; +} + +std::shared_ptr CRadixTreeIndex::match_prefix( + torch::Tensor &block_hashes, int num_blocks, bool update_cache_info) { + auto current_node = root; + auto last_ready_node = root; + auto prefix_blocks_num = 0; + auto ready_prefix_blocks_num = 0; + auto last_node_matched_length = 0; + auto physical_blocks = new std::vector(); + auto block_hashes_ptr = block_hashes.data_ptr(); + HashType child_hash; + + while (prefix_blocks_num < num_blocks) { + if (update_cache_info) { + current_node->update_time(); + } + + child_hash = HashType(block_hashes_ptr[prefix_blocks_num + current_node->size()]); + if (current_node->lookup_child(child_hash)) { + if (current_node->is_ready()) { + last_ready_node = current_node; + ready_prefix_blocks_num += current_node->size(); + } + prefix_blocks_num += current_node->size(); + physical_blocks->insert(physical_blocks->end(), current_node->get_physical_blocks().begin(), + 
current_node->get_physical_blocks().end()); + current_node = current_node->get_child(child_hash); + } else { + auto matched_length = 0; + if (is_root(current_node) == false) { + auto cmp_length = std::min(current_node->size(), num_blocks - prefix_blocks_num); + auto left = 0; + auto right = cmp_length; + + while (left < right) { + auto mid = (left + right) / 2; + if (current_node->get_hash(mid) == HashType(block_hashes_ptr[prefix_blocks_num+mid])) { + left = mid + 1; + } else { + right = mid; + } + } + matched_length = left; + physical_blocks->insert(physical_blocks->end(), current_node->get_physical_blocks().begin(), + current_node->get_physical_blocks().begin() + matched_length); + } else { + matched_length = 0; + } + + if (current_node->is_ready()) { + last_ready_node = current_node; + ready_prefix_blocks_num += matched_length; + } + + last_node_matched_length = matched_length; + prefix_blocks_num += matched_length; + break; + } + } + + return std::make_shared(prefix_blocks_num, ready_prefix_blocks_num, last_node_matched_length, + last_ready_node, current_node, physical_blocks); +} + +} // namespace flexkv diff --git a/csrc/radix_tree.h b/csrc/radix_tree.h new file mode 100644 index 0000000000..63560a3a8d --- /dev/null +++ b/csrc/radix_tree.h @@ -0,0 +1,346 @@ +#pragma once +#include +#include +#include +#include +#include + +#include "cache_utils.h" + +namespace flexkv { + +class CRadixTreeIndex; + +class CRadixNode { +private: + bool on_leaf; + bool ready; + int lock_cnt; + time_t last_access_time; + + std::deque block_hashes; + std::deque physical_blocks; + std::map children; + + CRadixTreeIndex *index; + CRadixNode *parent; + +public: + CRadixNode(CRadixTreeIndex *index, bool ready, int lock_cnt); + ~CRadixNode(); + + struct Compare { + bool operator() (CRadixNode *a, CRadixNode *b) { + return a->get_time() > b->get_time(); + } + }; + + bool get_leaf_state() { + return on_leaf; + } + + void set_leaf_state(bool on_leaf) { + this->on_leaf = on_leaf; + } + + CRadixTreeIndex *get_index() { + return index; + } + + void set_time(time_t time) { + last_access_time = time; + } + + time_t get_time() { + return last_access_time; + } + + void update_time() { + struct timeval now; + + gettimeofday(&now, nullptr); + last_access_time = now.tv_sec * 1000 + now.tv_usec / 10000; + } + + CRadixNode *get_parent() { + return parent; + } + + void set_parent(CRadixNode *parent) { + this->parent = parent; + } + + void clear_parent() { + this->parent = nullptr; + } + + HashType get_hash(int pos) { + return HashType(block_hashes[pos]); + } + + HashType get_head_hash() { + if (size() > 0) { + return HashType(block_hashes[0]); + } else { + return HashType(0); + } + } + + int size() { + return block_hashes.size(); + } + + int get_num_children() { + return children.size(); + } + + std::deque &get_block_hashes() { + return block_hashes; + } + + std::deque &get_physical_blocks() { + return physical_blocks; + } + + bool lookup_child(HashType hash) { + auto iter = children.find(hash); + if (iter != children.end()) + return true; + else + return false; + } + + CRadixNode *get_child(HashType hash) { + return children.at(hash); + } + + void set_child(HashType hash, CRadixNode *node) { + children[hash] = node; + } + + void remove_child(HashType hash) { + children.erase(hash); + } + + bool is_leaf() { + return get_num_children() == 0; + } + + bool in_use() { + return lock_cnt > 0 || !ready; + } + + bool evictable() { + return is_leaf() && !in_use(); + } + + void lock() { + assert(lock_cnt >= 0); + lock_cnt++; + } + + 
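  // A node with lock_cnt > 0, or whose blocks are not yet ready, reports in_use() == true and is
  // therefore never selected by evict(); unlock() must balance a prior lock().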
void unlock() { + assert(lock_cnt > 0); + lock_cnt--; + } + + void set_ready(bool ready) { + this->ready = ready; + } + + bool is_ready() { + return ready; + } + + CRadixNode *split(int prefix_length); + std::deque *shrink(int length); + void merge_child(); +}; + +class CMatchResult { +public: + int num_ready_matched_blocks; + int num_matched_blocks; + int last_node_matched_length; + + CRadixNode *last_ready_node; + CRadixNode *last_node; + std::vector *physical_blocks; + + CMatchResult(int _num_ready_matched_blocks, int _num_matched_blocks, int _last_node_matched_length, + CRadixNode *_last_ready_node, CRadixNode *_last_node, std::vector *blocks) + : num_ready_matched_blocks(_num_ready_matched_blocks), num_matched_blocks(_num_matched_blocks), + last_node_matched_length(_last_node_matched_length), last_ready_node(_last_ready_node), + last_node(_last_node), physical_blocks(blocks) { + } + + ~CMatchResult() { + delete physical_blocks; + }; +}; + +class CRadixTreeIndex { +private: + CRadixNode *root; + std::list node_list; + std::list leaf_list; + + int max_num_blocks; + int tokens_per_block; + int node_count; + +public: + CRadixTreeIndex(int tokens_per_block, int max_num_blocks = 1000000) { + this->tokens_per_block = tokens_per_block; + this->max_num_blocks = max_num_blocks; + this->node_count = 0; + + root = new CRadixNode(this, true, 0); + node_list.push_back(root); + } + + ~CRadixTreeIndex() { + leaf_list.clear(); + + while (node_list.size()) { + auto node = node_list.front(); + node->set_parent(nullptr); + node_list.pop_front(); + delete node; + } + + if (node_count) { + std::cerr << "CRadix Node count" << node_count << std::endl; + } + } + + void reset() { + leaf_list.clear(); + + while (node_list.size()) { + auto node = node_list.front(); + node->set_parent(nullptr); + node_list.pop_front(); + delete node; + } + + root = new CRadixNode(this, true, 0); + node_list.push_back(root); + } + + bool is_root(CRadixNode *node) { + return node == root; + } + + CRadixNode *get_root() { + return root; + } + + void remove_node(CRadixNode *node) { + assert(node != root); + assert(node->get_parent() == nullptr); + + node_list.remove(node); + delete node; + } + + void remove_leaf(CRadixNode *node) { + assert(node != root); + assert(node->get_leaf_state()); + + if (node->get_leaf_state() == false) { + return; + } + + leaf_list.remove(node); + node->set_leaf_state(false); + } + + void add_node(CRadixNode *node) { + assert(node != nullptr); + assert(node->get_parent() != nullptr); + node_list.push_back(node); + } + + void add_leaf(CRadixNode *node) { + assert(node != nullptr); + assert(node->get_leaf_state() == false); + + if (node->get_leaf_state() == true) { + return; + } + + leaf_list.push_back(node); + node->set_leaf_state(true); + } + + void lock(CRadixNode *node) { + node->lock(); + } + + void unlock(CRadixNode *node) { + node->unlock(); + } + + bool is_empty() { + return node_list.size() == 1; + } + + void inc_node_count() { + node_count++; + } + + void dec_node_count() { + node_count--; + } + + void set_ready(CRadixNode *node, bool ready = true, int ready_length = -1) { + node->set_ready(ready); + if (ready_length > 0) { + ready_length -= node->size(); + while (ready_length > 0) { + assert(node->get_parent() != nullptr); + node = node->get_parent(); + ready_length -= node->size(); + node->set_ready(true); + } + assert(ready_length == 0); + } + } + + int total_node_num() { + return node_list.size() - 1; + } + + int total_cached_blocks() { + auto total_blocks = 0; + + for (auto it = 
node_list.begin(); it != node_list.end(); it++) { + total_blocks += (*it)->size(); + } + return total_blocks; + } + + int total_ready_blocks() { + auto total_blocks = 0; + for (auto it = node_list.begin(); it != node_list.end(); it++) { + if ((*it)->is_ready()) { + total_blocks += (*it)->size(); + } + } + return total_blocks; + } + + int total_unready_blocks() { + return total_cached_blocks() - total_ready_blocks(); + } + + int evict(torch::Tensor &evicted_blocks, int num_evicted); + std::shared_ptr match_prefix(torch::Tensor &block_hashes, + int num_blocks, bool update_cache_info = true); + CRadixNode *insert(torch::Tensor &physical_block_ids, torch::Tensor &block_hashes, int num_blocks, + int num_insert_blocks, bool ready = true, CRadixNode *node = nullptr, int num_matched_blocks = -1, + int last_node_matched_length = -1); +}; + +} // namespace flexkv diff --git a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index 617ac8abd2..06cb45c4fe 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ b/csrc/tp_transfer_thread_group.cpp @@ -22,8 +22,31 @@ namespace flexkv { TPTransferThreadGroup::TPTransferThreadGroup( int num_gpus, const std::vector> &gpu_blocks, - torch::Tensor &cpu_blocks, int dp_group_id) { + torch::Tensor &cpu_blocks, int dp_group_id, + torch::Tensor &gpu_kv_strides_tensor, + torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_chunk_sizes_tensor) { + num_gpus_ = num_gpus; + + gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_block_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_chunk_sizes_in_bytes_ = new int64_t[num_gpus]; + + int64_t* kv_strides_ptr = gpu_kv_strides_tensor.data_ptr(); + int64_t* block_strides_ptr = gpu_block_strides_tensor.data_ptr(); + int64_t* chunk_sizes_ptr = gpu_chunk_sizes_tensor.data_ptr(); + + for (int i = 0; i < num_gpus; i++) { + gpu_kv_strides_in_bytes_[i] = kv_strides_ptr[i]; + gpu_block_strides_in_bytes_[i] = block_strides_ptr[i]; + gpu_chunk_sizes_in_bytes_[i] = chunk_sizes_ptr[i]; + } + + queues_.resize(num_gpus_); + mtxs_ = std::vector(num_gpus_); + cvs_ = std::vector(num_gpus_); + int num_layers = gpu_blocks[0].size(); cudaMallocHost((void **)&gpu_blocks_, num_gpus_ * num_layers * sizeof(void *)); @@ -41,15 +64,55 @@ TPTransferThreadGroup::TPTransferThreadGroup( cudaSetDevice(dp_group_id * num_gpus_ + i); cudaStreamCreate(&streams_[i]); } + // create the thread pool + stop_pool_=false; + for (int i = 0; i < num_gpus_; ++i) { + threads_.emplace_back([this, i]() { + int device_id = dp_group_id_ * num_gpus_ + i; + cudaSetDevice(device_id); // only once + + while (true) { + Task task; + { + std::unique_lock lk(mtxs_[i]); + cvs_[i].wait(lk, [&]{ return stop_pool_ || !queues_[i].empty(); }); + if (stop_pool_ && queues_[i].empty()) return; + + task = std::move(queues_[i].front()); + queues_[i].pop(); + } + task(); // + } + }); + } + +} + +TPTransferThreadGroup::~TPTransferThreadGroup() { + stop_pool_ = true; + for (auto& cv : cvs_) cv.notify_all(); + for (auto& t : threads_) if (t.joinable()) t.join(); + + cudaFreeHost(gpu_blocks_); + + delete[] gpu_kv_strides_in_bytes_; + delete[] gpu_block_strides_in_bytes_; + delete[] gpu_chunk_sizes_in_bytes_; } -TPTransferThreadGroup::~TPTransferThreadGroup() {} +std::future TPTransferThreadGroup::enqueue_for_gpu(int gpu_idx, Task task) { + auto pkg = std::make_shared>(std::move(task)); + auto fut = pkg->get_future(); + { + std::lock_guard lk(mtxs_[gpu_idx]); + queues_[gpu_idx].emplace([pkg]{ (*pkg)(); }); + } + cvs_[gpu_idx].notify_one(); + return fut; +} void 
TPTransferThreadGroup::tp_group_transfer( const torch::Tensor &gpu_block_id_tensor, - const int64_t gpu_kv_stride_in_bytes, - const int64_t gpu_block_stride_in_bytes, - const int64_t gpu_chunk_size_in_bytes, const torch::Tensor &cpu_block_id_tensor, const int64_t cpu_kv_stride_in_bytes, const int64_t cpu_layer_stride_in_bytes, @@ -60,11 +123,15 @@ void TPTransferThreadGroup::tp_group_transfer( std::atomic failed{false}; std::string error_msg; - threads_.clear(); - threads_.reserve(num_gpus_); + // threads_.clear(); + // threads_.reserve(num_gpus_); - for (int i = 0; i < num_gpus_; ++i) { - threads_.emplace_back([&, i]() { + // Barrier sync_point(num_gpus_); + std::vector> futures; + futures.reserve(num_gpus_); + + for (int i=0; i(gpu_blocks_ + i * num_layers + layer_id); void *cpu_ptr = cpu_blocks_; int64_t cpu_startoff_inside_chunks = - is_mla ? 0 : i * gpu_chunk_size_in_bytes; - cudaSetDevice(dp_group_id_ * num_gpus_ + i); + is_mla ? 0 : i * gpu_chunk_sizes_in_bytes_[i]; + flexkv::transfer_kv_blocks( - num_blocks, layer_id, layer_granularity, gpu_block_ids, - gpu_layer_ptrs, gpu_kv_stride_in_bytes, gpu_block_stride_in_bytes, - cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, - cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, - cpu_startoff_inside_chunks, gpu_chunk_size_in_bytes, streams_[i], - transfer_sms, is_host_to_device, use_ce_transfer, is_mla); + num_blocks, layer_id, layer_granularity, gpu_block_ids, + gpu_layer_ptrs, gpu_kv_strides_in_bytes_[i], gpu_block_strides_in_bytes_[i], + cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, + cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, + cpu_startoff_inside_chunks, gpu_chunk_sizes_in_bytes_[i], streams_[i], + transfer_sms, is_host_to_device, use_ce_transfer, is_mla + ); + cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { failed = true; @@ -95,12 +164,12 @@ void TPTransferThreadGroup::tp_group_transfer( failed = true; error_msg = e.what(); } - }); + + })); } - for (auto &t : threads_) { - if (t.joinable()) - t.join(); + for (auto &f : futures){ + f.get(); } if (failed) { diff --git a/csrc/tp_transfer_thread_group.h b/csrc/tp_transfer_thread_group.h index 315c36c541..3d57e569c5 100644 --- a/csrc/tp_transfer_thread_group.h +++ b/csrc/tp_transfer_thread_group.h @@ -24,20 +24,22 @@ #include #include #include - +#include +#include +#include +#include namespace flexkv { - class TPTransferThreadGroup { public: TPTransferThreadGroup( int num_gpus, const std::vector> &gpu_blocks, - torch::Tensor &cpu_blocks, int dp_group_id); + torch::Tensor &cpu_blocks, int dp_group_id, + torch::Tensor &gpu_kv_strides_tensor, + torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_chunk_sizes_tensor); ~TPTransferThreadGroup(); void tp_group_transfer(const torch::Tensor &gpu_block_id_tensor, - const int64_t gpu_kv_stride_in_bytes, - const int64_t gpu_block_stride_in_bytes, - const int64_t gpu_chunk_size_in_bytes, const torch::Tensor &cpu_block_id_tensor, const int64_t cpu_kv_stride_in_bytes, const int64_t cpu_layer_stride_in_bytes, @@ -48,12 +50,25 @@ class TPTransferThreadGroup { const int layer_granularity, const bool is_mla); private: + using Task = std::function; + std::future enqueue_for_gpu(int gpu_idx, Task task); + int num_gpus_; int dp_group_id_; void **gpu_blocks_; void *cpu_blocks_; + + int64_t *gpu_kv_strides_in_bytes_; + int64_t *gpu_block_strides_in_bytes_; + int64_t *gpu_chunk_sizes_in_bytes_; + std::vector threads_; std::vector streams_; + + std::vector> queues_; + std::vector mtxs_; + std::vector cvs_; + std::atomic 
stop_pool_; }; } // namespace flexkv diff --git a/csrc/transfer_ssd.h b/csrc/transfer_ssd.h index f795f4de54..9d6f9ba57e 100644 --- a/csrc/transfer_ssd.h +++ b/csrc/transfer_ssd.h @@ -173,11 +173,10 @@ class IOUring { class SSDIOCTX { public: - SSDIOCTX(std::map>& ssd_files, - int num_devices, int iouring_entries, int iouring_flags) - : iouring(iouring_entries, iouring_flags), - fds_buffer_io(num_devices), - fds_direct_io(num_devices) { + SSDIOCTX(std::map> &ssd_files, int num_devices, + int iouring_entries, int iouring_flags) + : iouring(iouring_entries, iouring_flags), fds_buffer_io(num_devices), + fds_direct_io(num_devices) { int i, j, fd_buffer_io, fd_direct_io; @@ -190,7 +189,8 @@ class SSDIOCTX { fd_direct_io = open(ssd_files[i][j].c_str(), O_RDWR | O_DIRECT); if (fd_buffer_io < 0 || fd_direct_io < 0) { - std::cerr << "open file failed, path = " << ssd_files[i][j] << std::endl; + std::cerr << "open file failed, path = " << ssd_files[i][j] + << std::endl; throw std::runtime_error("Failed to open file"); } else { posix_fadvise(fd_buffer_io, 0, 0, POSIX_FADV_SEQUENTIAL); @@ -221,20 +221,14 @@ class SSDIOCTX { } } - int get_num_devices() { - return num_devices; - } + int get_num_devices() { return num_devices; } - int get_num_files_per_device() { - return num_files_per_device; - } + int get_num_files_per_device() { return num_files_per_device; } - IOUring &get_iouring() { - return iouring; - } + IOUring &get_iouring() { return iouring; } std::vector> &get_fds(bool is_read, bool is_direct) { - if (is_read && is_direct) { + if (is_direct) { return fds_direct_io; } else { return fds_buffer_io; @@ -250,15 +244,13 @@ class SSDIOCTX { std::vector> fds_direct_io; }; - void transfer_kv_blocks_ssd( - SSDIOCTX &ioctx, - const torch::Tensor &cpu_layer_id_list, int64_t cpu_tensor_ptr, - const torch::Tensor &ssd_block_ids, const torch::Tensor &cpu_block_ids, - int64_t cpu_layer_stride_in_bytes, int64_t cpu_kv_stride_in_bytes, - int64_t ssd_layer_stride_in_bytes, int64_t ssd_kv_stride_in_bytes, - int64_t chunk_size_in_bytes, int64_t block_stride_in_bytes, bool is_read, - int num_blocks_per_file, int round_robin = 1, - int num_threads_per_device = 16, bool is_mla = false); + SSDIOCTX &ioctx, const torch::Tensor &cpu_layer_id_list, + int64_t cpu_tensor_ptr, const torch::Tensor &ssd_block_ids, + const torch::Tensor &cpu_block_ids, int64_t cpu_layer_stride_in_bytes, + int64_t cpu_kv_stride_in_bytes, int64_t ssd_layer_stride_in_bytes, + int64_t ssd_kv_stride_in_bytes, int64_t chunk_size_in_bytes, + int64_t block_stride_in_bytes, bool is_read, int num_blocks_per_file, + int round_robin = 1, int num_threads_per_device = 16, bool is_mla = false); } // namespace flexkv diff --git a/docs/vllm_adapter/README_en.md b/docs/vllm_adapter/README_en.md new file mode 100644 index 0000000000..781b3ad3ee --- /dev/null +++ b/docs/vllm_adapter/README_en.md @@ -0,0 +1,84 @@ +# Using FlexKV in vLLM + +## Current Version vs. Legacy Version +In commit [`0290841dce65ae9b036a23d733cf94e47e814934`](https://github.com/taco-project/FlexKV/commit/0290841dce65ae9b036a23d733cf94e47e814934), we introduced a major update: +**FlexKV has transitioned from a client-server architecture to a library function that inference acceleration engines (such as vLLM) can directly invoke**, reducing inter-process communication overhead. + +This change involves significant API adjustments. Therefore, please note: + +- **Version >= `1.0.0`**: Use the **current version API**; the vLLM patch is located in `examples/vllm_adaption/`. 
+- **Version == `0.1.0`**: Supports the **legacy version API**; the vLLM patch is located in `examples/vllm_adaption_legacy/`.
+
+---
+
+## Current Version (>= 1.0.0)
+
+### Supported Versions
+- FlexKV >= `1.0.0`
+- vLLM versions >= `0.8.5` can generally be adapted by following this example
+
+### Example
+We provide an adaptation example based on **vLLM 0.10.1.1**:
+
+1. Apply the patch
+```bash
+# FLEXKV_DIR/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch
+git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch
+```
+
+2. Offline test
+```bash
+# VLLM_DIR/examples/offline_inference/prefix_caching_flexkv.py
+python examples/offline_inference/prefix_caching_flexkv.py
+```
+
+3. Online serving
+```bash
+# generate config
+cat <<EOF > ./flexkv_config.json
+{
+    "server_recv_port": "ipc:///tmp/flexkv_test",
+    "cache_config": {
+        "enable_cpu": true,
+        "num_cpu_blocks": 10240
+    },
+    "num_log_interval_requests": 200
+}
+EOF
+export FLEXKV_CONFIG_PATH="./flexkv_config.json"
+
+VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \
+    --tensor-parallel-size 8 \
+    --trust-remote-code \
+    --port 30001 \
+    --max-num-seqs 128 \
+    --max-num-batched-tokens 8192 \
+    --max_model_len 8192 \
+    --max-seq-len-to-capture 8192 \
+    --gpu-memory-utilization 0.8 \
+    --enable-chunked-prefill \
+    --enable-prefix-caching \
+    --kv-transfer-config \
+    '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}'
+
+```
+
+## Legacy Version (<= 0.1.0) – Not Recommended for Current Use
+
+### Supported Versions
+- FlexKV <= `0.1.0`
+
+### Example
+Apply the patch `examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch` to vLLM 0.8.4, then start FlexKV, vLLM, and the benchmark script:
+
+```bash
+# Start FlexKV as server
+bash benchmarks/flexkv_benchmark/run_flexkv_server.sh
+
+# Start vLLM as client
+bash benchmarks/flexkv_benchmark/serving_vllm.sh
+
+# Start benchmark
+bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh
+```
+Apply the patch `examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch` to vLLM 0.10.0, and use the same testing method as above.
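The sketch below shows the same setup driven from Python rather than the shell commands above. It only mirrors the config fields and the `kv_transfer_config` already used in this patch; the model path, parallelism, and memory settings are placeholders, not FlexKV defaults.

```python
# Minimal sketch (not part of the patch): write the FlexKV config shown above,
# expose it via FLEXKV_CONFIG_PATH, and create a vLLM engine that selects the
# FlexKV connector through kv_transfer_config.
import json
import os

from vllm import LLM, SamplingParams

flexkv_config = {
    "server_recv_port": "ipc:///tmp/flexkv_test",
    "cache_config": {
        "enable_cpu": True,
        "num_cpu_blocks": 10240,
    },
    "num_log_interval_requests": 200,
}
config_path = "./flexkv_config.json"
with open(config_path, "w") as f:
    json.dump(flexkv_config, f)
# The connector locates its config through this variable; set it before creating the engine.
os.environ["FLEXKV_CONFIG_PATH"] = config_path

llm = LLM(
    model="Qwen3/Qwen3-32B",        # placeholder model path
    tensor_parallel_size=8,         # placeholder parallelism
    gpu_memory_utilization=0.8,
    enable_prefix_caching=True,
    kv_transfer_config={"kv_connector": "FlexKVConnectorV1", "kv_role": "kv_both"},
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0.0))
print(outputs[0].outputs[0].text)
```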
diff --git a/docs/vllm_adapter/README_zh.md b/docs/vllm_adapter/README_zh.md
new file mode 100644
index 0000000000..0e7ce7687e
--- /dev/null
+++ b/docs/vllm_adapter/README_zh.md
@@ -0,0 +1,83 @@
+# 在 vLLM 中使用 FlexKV
+
+## 当前版本与 Legacy 版本说明
+在 commit [`0290841dce65ae9b036a23d733cf94e47e814934`](https://github.com/taco-project/FlexKV/commit/0290841dce65ae9b036a23d733cf94e47e814934),我们更新了一个重要功能:
+ **FlexKV 从 client-server 模式,变为推理加速引擎(如 vLLM)可直接调用的库函数**,以减少进程间消息传递的开销。
+这一变更引发了较大的 API 调整。因此,请注意:
+
+- **版本 >= `1.0.0`**:应使用 **当前版本 API**,vLLM patch位于 `examples/vllm_adaption/`。
+- **版本 == `0.1.0`**:仅支持 **Legacy 版本 API**, vLLM patch位于`examples/vllm_adaption_legacy/`。
+
+---
+
+## 当前版本(>= 1.0.0)
+
+### 适用版本
+- FlexKV >= `1.0.0`
+- vLLM 原则上>= `0.8.5`版本均可参考示例代码进行修改
+
+### 示例
+我们提供了基于 **vLLM 0.10.1.1** 的适配示例:
+
+1. apply patch
+```bash
+# FLEXKV_DIR/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch
+git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch
+```
+
+2. offline test
+```bash
+# VLLM_DIR/examples/offline_inference/prefix_caching_flexkv.py
+python examples/offline_inference/prefix_caching_flexkv.py
+```
+
+3. online serving
+```bash
+# generate config
+cat <<EOF > ./flexkv_config.json
+{
+    "server_recv_port": "ipc:///tmp/flexkv_test",
+    "cache_config": {
+        "enable_cpu": true,
+        "num_cpu_blocks": 10240
+    },
+    "num_log_interval_requests": 200
+}
+EOF
+export FLEXKV_CONFIG_PATH="./flexkv_config.json"
+
+VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \
+    --tensor-parallel-size 8 \
+    --trust-remote-code \
+    --port 30001 \
+    --max-num-seqs 128 \
+    --max-num-batched-tokens 8192 \
+    --max_model_len 8192 \
+    --max-seq-len-to-capture 8192 \
+    --gpu-memory-utilization 0.8 \
+    --enable-chunked-prefill \
+    --enable-prefix-caching \
+    --kv-transfer-config \
+    '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}'
+
+```
+
+## Legacy版本(<= 0.1.0),目前的版本尽量不要使用
+
+### 适用版本
+- FlexKV <= `0.1.0`
+
+### 示例
+在 vLLM 0.8.4 版本中应用patch `examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch`,分别启动 FlexKV、vLLM 和测试脚本:
+
+```bash
+# 启动 FlexKV 作为服务端
+bash benchmarks/flexkv_benchmark/run_flexkv_server.sh
+
+# 启动 vLLM 作为客户端
+bash benchmarks/flexkv_benchmark/serving_vllm.sh
+
+# 启动性能测试
+bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh
+```
+在 vLLM 0.10.0 版本中应用patch `examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch`,测试方法同上。
diff --git a/examples/run_server.py b/examples/run_server.py index d5b6a182ec..48b24ecad1 100644 --- a/examples/run_server.py +++ b/examples/run_server.py @@ -12,16 +12,16 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - + # NAME - parser.add_argument("--enable-cpu", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-cpu", + action=argparse.BooleanOptionalAction, default=True) - parser.add_argument("--enable-ssd", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-ssd", + action=argparse.BooleanOptionalAction, default=False,) - parser.add_argument("--enable-remote", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-remote", + action=argparse.BooleanOptionalAction, default=False,) parser.add_argument("--model-path", type=str, help="model path", default="") parser.add_argument("--tp-size", type=int, default=1) @@ -54,7 +54,7 @@ def parse_args() -> argparse.Namespace: if __name__ == "__main__": args = parse_args() hf_config = AutoConfig.from_pretrained(args.model_path) - + num_layers=hf_config.num_hidden_layers if hasattr(hf_config, 'num_key_value_heads'): num_kv_heads=hf_config.num_key_value_heads @@ -65,7 +65,7 @@ def parse_args() -> argparse.Namespace: head_size=(hf_config.head_dim if hasattr(hf_config, 'head_dim') else hf_config.hidden_size//hf_config.num_attention_heads) use_mla=hf_config.architectures[0].startswith("Deepseek") - + # TODO: different model config may have different attribute name model_config = ModelConfig( num_layers=num_layers, @@ -76,14 +76,13 @@ def parse_args() -> argparse.Namespace: dp_size=args.dp_size, dtype=hf_config.torch_dtype ) - + cache_config = CacheConfig( enable_cpu=args.enable_cpu, enable_ssd=args.enable_ssd, enable_remote=args.enable_remote, use_gds=False, enable_trace=False, - use_pinned_memory=False, ssd_cache_iouring_entries=512, tokens_per_block=args.block_size, num_cpu_blocks=args.num_cpu_blocks, @@ -93,6 +92,6 @@ def parse_args() -> argparse.Namespace: remote_cache_size_mode=args.remote_cache_size_mode, remote_cache_path=args.remote_cache_path, ) - + kvserver = KVServer(model_config, cache_config, args.server_recv_port) - kvserver.run() \ No newline at end of file + kvserver.run() diff --git a/examples/scheduler_server_example.py 
b/examples/scheduler_server_example.py index 1aae7ec298..059cc467aa 100644 --- a/examples/scheduler_server_example.py +++ b/examples/scheduler_server_example.py @@ -16,9 +16,9 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, model_config, gpu_kv_layout): """Run TP client process""" from flexkv.server.client import KVTPClient - + print(f"Starting TP client: dp_client_id={dp_client_id}, tp_rank={tp_rank}, device_id={device_id}") - + try: # Set CUDA device for this process if torch.cuda.is_available(): @@ -27,8 +27,8 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, mo torch.cuda.init() # Clear cache torch.cuda.empty_cache() - - tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank) + + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) # Create GPU blocks for this TP client gpu_blocks = [] @@ -51,7 +51,7 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, mo # Keep TP client running while True: time.sleep(1) - + except Exception as e: print(f"TP client {tp_rank} error: {e}") import traceback @@ -84,7 +84,6 @@ def main(): enable_ssd=False, enable_remote=False, use_gds=False, - use_pinned_memory=True, tokens_per_block=tokens_per_block, num_cpu_blocks=num_cpu_blocks, ) @@ -106,14 +105,14 @@ def main(): cache_config=cache_config, server_recv_port="ipc:///tmp/scheduler_server_example" # TPClient connects to this port ) - + # Start background server thread to handle TPClient registration scheduler_server.start_server_thread() - - print(f"SchedulerServer started!") + + print("SchedulerServer started!") print(f"TPClient can connect to: {scheduler_server.get_server_port()}") print("Starting TP client processes...") - + # Start TP client processes tp_client_processes = [] for tp_rank in range(tp_size): @@ -123,7 +122,7 @@ def main(): if device_id >= available_gpus: device_id = device_id % available_gpus print(f"Warning: Using GPU {device_id} for TP rank {tp_rank} (not enough GPUs)") - + tp_client_process = Process( target=run_tp_client_process, args=(0, tp_rank, device_id, scheduler_server.get_server_port(), model_config, gpu_kv_layout), @@ -134,32 +133,32 @@ def main(): print(f"Started TP client process for rank {tp_rank} on device {device_id}") print("Waiting for all TP clients to register...") - + time.sleep(5) - + # Now we can directly use scheduler_server without network communication # Example: Create some test data (following benchmark_kvmanager.py pattern) batch_size = 4 seq_len = 128 - + print("\n=== Generating test data ===") # Generate separate sequences for each request (correct approach) batch_token_ids = [] batch_slot_mappings = [] batch_token_masks = [] - + for i in range(batch_size): # Each sequence is independent (seq_len,) shape token_ids = torch.randint(0, 1000, (seq_len,)) slot_mapping = torch.arange(i * seq_len, (i + 1) * seq_len) token_mask = torch.ones(seq_len, dtype=torch.bool) - + batch_token_ids.append(token_ids) batch_slot_mappings.append(slot_mapping) batch_token_masks.append(token_mask) - + print(f"Generated {batch_size} sequences, each with {seq_len} tokens") - + print("\n=== Executing PUT Operations ===") # PUT operations - each sequence processed separately start_time = time.time() @@ -173,7 +172,7 @@ def main(): if task_id: put_task_ids.append(task_id) print(f"PUT task {task_id} created for sequence {i}") - + put_time = (time.time() - start_time) * 1000 print(f"Created {len(put_task_ids)} PUT tasks, time: {put_time:.2f}ms") time.sleep(2) @@ 
-190,10 +189,10 @@ def main(): if task_id: get_task_ids.append(task_id) print(f"GET task {task_id} created for sequence {i}") - + get_time = (time.time() - start_time) * 1000 print(f"Created {len(get_task_ids)} GET tasks, time: {get_time:.2f}ms") - + print("\n=== Waiting for All Tasks to Complete ===") # Wait for all tasks to complete - can wait for multiple tasks at once all_task_ids = put_task_ids + get_task_ids @@ -202,7 +201,7 @@ def main(): masks = scheduler_server.wait(all_task_ids) wait_time = (time.time() - start_time) * 1000 print(f"All {len(all_task_ids)} tasks completed, time: {wait_time:.2f}ms") - + # Analyze results if masks: total_tokens = 0 @@ -211,7 +210,7 @@ def main(): tokens = mask.sum().item() if hasattr(mask, 'sum') else len(mask) total_tokens += tokens print(f"Task {task_id}: {tokens} tokens processed") - + print("\n=== Trying Non-blocking Wait ===") # Create a few more tasks and try non-blocking wait extra_task_ids = [] @@ -223,7 +222,7 @@ def main(): ) if task_id: extra_task_ids.append(task_id) - + if extra_task_ids: # Immediately try to wait (might not be completed yet) masks = scheduler_server.try_wait(extra_task_ids) @@ -233,15 +232,15 @@ def main(): print(f"Tasks {extra_task_ids} not ready yet, will wait...") masks = scheduler_server.wait(extra_task_ids) print(f"Tasks {extra_task_ids} completed after wait") - + print("\n✅ All operations completed successfully!") - - + + # Clean up resources print("\n=== Shutting down SchedulerServer ===") scheduler_server.shutdown() print("SchedulerServer has been shut down") - + # Terminate TP client processes print("Terminating TP client processes...") for i, process in enumerate(tp_client_processes): @@ -253,4 +252,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch b/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch new file mode 100644 index 0000000000..812a1d6e2f --- /dev/null +++ b/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch @@ -0,0 +1,453 @@ +diff --git a/examples/offline_inference/prefix_caching_flexkv.py b/examples/offline_inference/prefix_caching_flexkv.py +new file mode 100644 +index 000000000..a57328ffd +--- /dev/null ++++ b/examples/offline_inference/prefix_caching_flexkv.py +@@ -0,0 +1,162 @@ ++# SPDX-License-Identifier: Apache-2.0 ++import os ++import time ++import json ++ ++from vllm import LLM, SamplingParams ++from vllm.distributed import cleanup_dist_env_and_memory ++ ++# NOTE: This is just a running example. For benchmarking purpose, ++# please see benchmarks/benchmark_prefix_caching.py ++ ++ ++flexkv_config = { ++ "server_recv_port": "ipc:///tmp/flexkv_test", ++ "cache_config": { ++ "enable_cpu": True, ++ "num_cpu_blocks": 10240, ++ }, ++ "num_log_interval_requests": 200 ++} ++flexkv_config_path = "./flexkv_config.json" ++with open(flexkv_config_path, 'w') as f: ++ json.dump(flexkv_config, f) ++os.environ["FLEXKV_CONFIG_PATH"] = flexkv_config_path ++ ++ ++# Common prefix. ++prefix = ( ++ "You are an expert school principal, skilled in effectively managing " ++ "faculty and staff. Draft 10-15 questions for a potential first grade " ++ "Head Teacher for my K-12, all-girls', independent school that emphasizes " ++ "community, joyful discovery, and life-long learning. The candidate is " ++ "coming in for a first-round panel interview for a 8th grade Math " ++ "teaching role. 
They have 5 years of previous teaching experience " ++ "as an assistant teacher at a co-ed, public school with experience " ++ "in middle school math teaching. Based on these information, fulfill " ++ "the following paragraph: ") ++ ++# Sample prompts. ++prompts = [ ++ "Hello, my name is", ++ "The president of the United States is", ++ "The capital of France is", ++ "The future of AI is", ++] ++ ++generating_prompts = [prefix + prompt for prompt in prompts] ++ ++# Create a sampling params object. ++sampling_params = SamplingParams(temperature=0.0) ++ ++kv_transfer_config = { ++ "kv_connector": "FlexKVConnectorV1", ++ "kv_role": "kv_both", ++} ++# model_path = "/data0/models/facebook/opt-125m" ++model_path = "/data0/models/Qwen3/Qwen3-32B" ++tp_size = 8 ++gpu_memory_utilization = 0.4 ++ ++ ++ ++def main(): ++ # Create an LLM without prefix caching as a baseline. ++ regular_llm = LLM(model=model_path, ++ enable_prefix_caching=False, ++ gpu_memory_utilization=gpu_memory_utilization, ++ tensor_parallel_size=tp_size ++ ) ++ ++ print("Results without `enable_prefix_caching`") ++ ++ # ruff: noqa: E501 ++ # Generate texts from the prompts. The output is a list of RequestOutput objects ++ # that contain the prompt, generated text, and other information. ++ outputs = regular_llm.generate(generating_prompts, sampling_params) ++ ++ regular_generated_texts = [] ++ # Print the outputs. ++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ regular_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Destroy the LLM object and free up the GPU memory. ++ del regular_llm ++ cleanup_dist_env_and_memory() ++ ++ # return ++ ++ # Create an LLM with prefix caching enabled. ++ prefix_cached_llm = LLM(model=model_path, ++ enable_prefix_caching=True, ++ gpu_memory_utilization=gpu_memory_utilization, ++ tensor_parallel_size=tp_size, ++ kv_transfer_config=kv_transfer_config, ++ ) ++ ++ # Warmup so that the shared prompt's KV cache is computed. ++ prefix_cached_llm.generate(generating_prompts[0], sampling_params) ++ ++ # wait for offload kv task finished. ++ time.sleep(2) ++ ++ # Generate with prefix caching. ++ outputs = prefix_cached_llm.generate(generating_prompts, sampling_params) ++ ++ print("Results with `enable_prefix_caching`") ++ ++ cached_generated_texts = [] ++ # Print the outputs. You should see the same outputs as before. ++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ cached_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Compare the results and display the speedup ++ generated_same = all([ ++ regular_generated_texts[i] == cached_generated_texts[i] ++ for i in range(len(prompts)) ++ ]) ++ print(f"Generated answers are the same: {generated_same}") ++ ++ # wait for offload kv task finished. ++ time.sleep(2) ++ ++ # reset prefix cache to use flexkv ++ prefix_cached_llm.reset_prefix_cache() ++ ++ # Generate with prefix caching. ++ outputs = prefix_cached_llm.generate(generating_prompts, sampling_params) ++ ++ print("Results with `flexkv`") ++ ++ flexkv_generated_texts = [] ++ # Print the outputs. You should see the same outputs as before. 
++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ flexkv_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Compare the results and display the speedup ++ generated_same = all([ ++ regular_generated_texts[i] == flexkv_generated_texts[i] ++ for i in range(len(prompts)) ++ ]) ++ print(f"Generated answers are the same: {generated_same}") ++ ++ ++ ++if __name__ == "__main__": ++ main() ++ # pass +\ No newline at end of file +diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py +index 584fc1d65..db1cfe36b 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/factory.py ++++ b/vllm/distributed/kv_transfer/kv_connector/factory.py +@@ -105,3 +105,8 @@ KVConnectorFactory.register_connector( + "MultiConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", + "MultiConnector") ++ ++KVConnectorFactory.register_connector( ++ "FlexKVConnectorV1", ++ "vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector", ++ "FlexKVConnectorV1") +diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py +new file mode 100644 +index 000000000..bdfa9f321 +--- /dev/null ++++ b/vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py +@@ -0,0 +1,191 @@ ++# SPDX-License-Identifier: Apache-2.0 ++# SPDX-FileCopyrightText: Copyright contributors to the vLLM project ++from typing import TYPE_CHECKING, Any, Optional ++ ++import torch ++from flexkv.integration.vllm.vllm_v1_adapter import FlexKVConnectorV1Impl ++ ++from vllm.config import VllmConfig ++from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ++ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) ++from vllm.logger import init_logger ++from vllm.v1.core.sched.output import SchedulerOutput ++from vllm.v1.outputs import KVConnectorOutput ++ ++if TYPE_CHECKING: ++ from vllm.attention.backends.abstract import AttentionMetadata ++ from vllm.forward_context import ForwardContext ++ from vllm.v1.core.kv_cache_manager import KVCacheBlocks ++ from vllm.v1.request import Request ++ ++logger = init_logger(__name__) ++ ++ ++class FlexKVConnectorV1(KVConnectorBase_V1): ++ ++ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): ++ super().__init__(vllm_config=vllm_config, role=role) ++ self._flexkv_connector = FlexKVConnectorV1Impl(vllm_config, role) ++ ++ def shutdown(self): ++ self._flexkv_connector.shutdown() ++ ++ # ============================== ++ # Worker-side methods ++ # ============================== ++ def start_load_kv(self, forward_context: "ForwardContext", ++ **kwargs) -> None: ++ """ ++ Start loading the KV cache from the connector to vLLM's paged ++ KV buffer. This is called from the forward context before the ++ forward pass to enable async loading during model execution. ++ ++ Args: ++ forward_context (ForwardContext): the forward context. ++ **kwargs: additional arguments for the load operation ++ ++ Note: ++ The number of elements in kv_caches and layer_names should be ++ the same. ++ ++ """ ++ self._flexkv_connector.start_load_kv(forward_context, **kwargs) ++ ++ def wait_for_layer_load(self, layer_name: str) -> None: ++ """ ++ Block until the KV for a specific layer is loaded into vLLM's ++ paged buffer. 
This is called from within attention layer to ensure ++ async copying from start_load_kv is complete. ++ ++ This interface will be useful for layer-by-layer pipelining. ++ ++ Args: ++ layer_name: the name of that layer ++ """ ++ self._flexkv_connector.wait_for_layer_load(layer_name) ++ ++ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, ++ attn_metadata: "AttentionMetadata", **kwargs) -> None: ++ """ ++ Start saving the a layer of KV cache from vLLM's paged buffer ++ to the connector. This is called from within attention layer to ++ enable async copying during execution. ++ ++ Args: ++ layer_name (str): the name of the layer. ++ kv_layer (torch.Tensor): the paged KV buffer of the current ++ layer in vLLM. ++ attn_metadata (AttentionMetadata): the attention metadata. ++ **kwargs: additional arguments for the save operation. ++ """ ++ self._flexkv_connector.save_kv_layer(layer_name, kv_layer, attn_metadata, ++ **kwargs) ++ ++ def wait_for_save(self): ++ """ ++ Block until all the save operations is done. This is called ++ as the forward context exits to ensure that the async saving ++ from save_kv_layer is complete before finishing the forward. ++ ++ This prevents overwrites of paged KV buffer before saving done. ++ """ ++ self._flexkv_connector.wait_for_save() ++ ++ def get_finished( ++ self, finished_req_ids: set[str] ++ ) -> tuple[Optional[set[str]], Optional[set[str]]]: ++ """ ++ Notifies worker-side connector ids of requests that have ++ finished generating tokens. ++ ++ Returns: ++ ids of requests that have finished asynchronous transfer ++ (requests that previously returned True from request_finished()), ++ tuple of (sending/saving ids, recving/loading ids). ++ The finished saves/sends req ids must belong to a set provided in a ++ call to this method (this call or a prior one). ++ """ ++ return self._flexkv_connector.get_finished(finished_req_ids) ++ ++ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ++ """ ++ Initialize with the KV caches. Useful for pre-registering the ++ KV Caches in the KVConnector (e.g. for NIXL). ++ ++ Args: kv_caches: ++ dictionary of layer names, kv cache ++ """ ++ self._flexkv_connector.register_kv_caches(kv_caches) ++ ++ # ============================== ++ # Scheduler-side methods ++ # ============================== ++ def get_num_new_matched_tokens( ++ self, ++ request: "Request", ++ num_computed_tokens: int, ++ ) -> tuple[int, bool]: ++ """ ++ Get number of new tokens that can be loaded from the ++ external KV cache beyond the num_computed_tokens. ++ ++ Args: ++ request (Request): the request object. ++ num_computed_tokens (int): the number of locally ++ computed tokens for this request ++ ++ Returns: ++ the number of tokens that can be loaded from the ++ external KV cache beyond what is already computed. ++ """ ++ return self._flexkv_connector.get_num_new_matched_tokens( ++ request, num_computed_tokens) ++ ++ def update_state_after_alloc(self, request: "Request", ++ blocks: "KVCacheBlocks", ++ num_external_tokens: int): ++ """ ++ Update KVConnector state after block allocation. ++ """ ++ self._flexkv_connector.update_state_after_alloc(request, blocks, ++ num_external_tokens) ++ ++ def build_connector_meta( ++ self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: ++ """ ++ Build the connector metadata for this step. ++ ++ This function should NOT modify fields in the scheduler_output. ++ Also, calling this function will reset the state of the connector. 
++ ++ Args: ++ scheduler_output (SchedulerOutput): the scheduler output object. ++ """ ++ return self._flexkv_connector.build_connector_meta(scheduler_output) ++ ++ def update_connector_output(self, connector_output: KVConnectorOutput): ++ """ ++ Update KVConnector state from worker-side connectors output. ++ ++ Args: ++ connector_output (KVConnectorOutput): the worker-side ++ connectors output. ++ """ ++ self._flexkv_connector.update_connector_output(connector_output) ++ ++ def request_finished( ++ self, ++ request: "Request", ++ block_ids: list[int], ++ ) -> tuple[bool, Optional[dict[str, Any]]]: ++ """ ++ Called when a request has finished, before its blocks are freed. ++ ++ Returns: ++ True if the request is being saved/sent asynchronously and blocks ++ should not be freed until the request_id is returned from ++ get_finished(). ++ Optional KVTransferParams to be included in the request outputs ++ returned by the engine. ++ """ ++ return self._flexkv_connector.request_finished(request, block_ids) +diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py +index 981023409..a6c8fac38 100644 +--- a/vllm/v1/core/sched/scheduler.py ++++ b/vllm/v1/core/sched/scheduler.py +@@ -118,6 +118,7 @@ class Scheduler(SchedulerInterface): + + # KV Connector: requests in process of async KV loading or recving + self.finished_recving_kv_req_ids: set[str] = set() ++ self.sending_kv_reqs: dict[str, Request] = {} + + # Encoder-related. + # Calculate encoder cache size if applicable +@@ -1029,7 +1030,8 @@ class Scheduler(SchedulerInterface): + + if not delay_free_blocks: + self._free_blocks(request) +- ++ else: ++ self.sending_kv_reqs[request.request_id] = request + return kv_xfer_params + + def _free_blocks(self, request: Request): +@@ -1041,7 +1043,7 @@ class Scheduler(SchedulerInterface): + return len(self.waiting) + len(self.running) + + def has_finished_requests(self) -> bool: +- return len(self.finished_req_ids) > 0 ++ return len(self.finished_req_ids) > 0 or len(self.sending_kv_reqs) > 0 + + def reset_prefix_cache(self) -> bool: + return self.kv_cache_manager.reset_prefix_cache() +@@ -1082,6 +1084,8 @@ class Scheduler(SchedulerInterface): + def shutdown(self) -> None: + if self.kv_event_publisher: + self.kv_event_publisher.shutdown() ++ if self.connector and hasattr(self.connector, "shutdown"): ++ self.connector.shutdown() + + ######################################################################## + # KV Connector Related Methods +@@ -1149,6 +1153,10 @@ class Scheduler(SchedulerInterface): + scheduler the request during the next step. 
+ """ + ++ # avoid busy checking ++ if len(self.running) == 0: ++ time.sleep(0.01) ++ + if self.connector is not None: + self.connector.update_connector_output(kv_connector_output) + +@@ -1158,4 +1166,5 @@ class Scheduler(SchedulerInterface): + self.finished_recving_kv_req_ids.add(req_id) + for req_id in (kv_connector_output.finished_sending or ()): + logger.debug("Finished sending KV transfer for request %s", req_id) ++ del self.sending_kv_reqs[req_id] + self._free_blocks(self.requests[req_id]) +diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py +index a03ebe35d..8e4460957 100644 +--- a/vllm/v1/worker/kv_connector_model_runner_mixin.py ++++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py +@@ -66,9 +66,9 @@ class KVConnectorModelRunnerMixin: + scheduler_output, wait_for_save=False) as kv_connector_output: + pass + +- if (not kv_connector_output.finished_sending +- and not kv_connector_output.finished_recving): +- return EMPTY_MODEL_RUNNER_OUTPUT ++ # if (not kv_connector_output.finished_sending ++ # and not kv_connector_output.finished_recving): ++ # return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output diff --git a/examples/vllm_adaption/flexkv_vllm_0_10_0.patch b/examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch similarity index 100% rename from examples/vllm_adaption/flexkv_vllm_0_10_0.patch rename to examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch diff --git a/examples/vllm_adaption/flexkv_vllm_0_8_4.patch b/examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch similarity index 100% rename from examples/vllm_adaption/flexkv_vllm_0_8_4.patch rename to examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 669163cc83..e113aeb69b 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -18,27 +18,134 @@ from functools import partial from queue import Queue from typing import List, Tuple, Optional, Dict, Callable +from dataclasses import dataclass, field +import numpy as np +import nvtx import torch +from flexkv.c_ext import CRadixNode, CRadixTreeIndex, CMatchResult from flexkv.cache.mempool import Mempool from flexkv.cache.radixtree import RadixTreeIndex, RadixNode, MatchResult -from flexkv.cache.transfer_pattern import ( - convert_read_graph_to_layer_wise_graph, add_virtal_op_for_mutiple_finished_ops -) +from flexkv.cache.transfer_pattern import add_virtal_op_for_mutiple_finished_ops from flexkv.common.block import SequenceMeta from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.exceptions import InvalidConfigError, NotEnoughSpaceError from flexkv.common.transfer import ( - DeviceType, TransferOpGraph, TransferOp, TransferType, TransferDescriptor + DeviceType, TransferOpGraph, TransferOp, TransferType ) +from flexkv.common.debug import flexkv_logger +@dataclass +class MatchResultAccel: + num_ready_matched_blocks: int = 0 + num_matched_blocks: int = 0 + last_ready_node: Optional['CRadixNode'] = None + last_node: Optional['CRadixNode'] = None + last_node_matched_length: int = 0 + physical_blocks: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + + def __post_init__(self) -> None: + assert self.physical_blocks.ndim == 1 + +class CacheEngineAccel: + def __init__(self, + device_type: DeviceType, + num_total_blocks: int, + tokens_per_block: int, + evict_ratio: float): + if not 
isinstance(device_type, DeviceType): + raise InvalidConfigError(f"Unknown device type: {device_type}") + if num_total_blocks <= 0: + raise InvalidConfigError(f"Invalid num_total_blocks: {num_total_blocks}") + if tokens_per_block <= 0 or (tokens_per_block & (tokens_per_block - 1)) != 0: + raise InvalidConfigError(f"Invalid tokens_per_block: {tokens_per_block}, " + f"tokens_per_block must be a power of 2") + + self.device_type = device_type + + self.index = CRadixTreeIndex(tokens_per_block, num_total_blocks) + + self.mempool = Mempool(num_total_blocks=num_total_blocks) + + self.tokens_per_block = tokens_per_block + self.num_total_blocks = num_total_blocks + self.evict_ratio = evict_ratio + + def reset(self) -> None: + self.index.reset() + self.mempool.reset() + + def match(self, sequence_meta: SequenceMeta) -> MatchResultAccel: + sequence_meta.gen_hashes() + match_result = self.index.match_prefix(torch.from_numpy(sequence_meta.block_hashes).to(torch.int64), + sequence_meta.num_blocks, True) + return MatchResultAccel(match_result.num_ready_matched_blocks, match_result.num_matched_blocks, + match_result.last_ready_node, match_result.last_node, + match_result.last_node_matched_length, + torch.tensor(match_result.physical_blocks, dtype=torch.int64).numpy()) + def insert(self, + sequence_meta: SequenceMeta, + physical_block_ids: torch.Tensor, + num_insert_blocks: int = -1, + is_ready: bool = True, + match_result: Optional[MatchResultAccel] = None) -> Optional[CRadixNode]: + sequence_meta.gen_hashes() + if match_result is None: + return self.index.insert(torch.from_numpy(physical_block_ids).to(torch.int64), + torch.from_numpy(sequence_meta.block_hashes).to(torch.int64), + sequence_meta.num_blocks, + num_insert_blocks, + is_ready) + else: + return self.index.insert(torch.from_numpy(physical_block_ids).to(torch.int64), + torch.from_numpy(sequence_meta.block_hashes).to(torch.int64), + sequence_meta.num_blocks, + num_insert_blocks, + is_ready, + match_result.last_node, + match_result.num_matched_blocks, + match_result.last_node_matched_length) + + def lock_node(self, node: CRadixNode) -> None: + self.index.lock(node) + + def cleanup(self, node: CRadixNode, cleanup_length: int) -> None: + self.index.unlock(node) + self.index.set_ready(node, True, cleanup_length) + + def take(self, + num_required_blocks: int, + protected_node: Optional[CRadixNode] = None, + strict: bool = True) -> torch.Tensor: + if num_required_blocks > self.mempool.num_free_blocks: + if protected_node is not None: + self.index.lock(protected_node) + evict_block_num = max(num_required_blocks - self.mempool.num_free_blocks, int(self.mempool.num_total_blocks * self.evict_ratio)) + target_blocks = torch.zeros(evict_block_num, dtype=torch.int64) + num_evicted = self.index.evict(target_blocks, evict_block_num) + if num_evicted != evict_block_num: + target_blocks.resize_(num_evicted) + self.mempool.recycle_blocks(target_blocks.numpy()) + + if protected_node is not None: + self.index.unlock(protected_node) + if strict and num_required_blocks > self.mempool.num_free_blocks: + raise NotEnoughSpaceError("Not enough free blocks to take, ", + required=num_required_blocks, + available=self.mempool.num_free_blocks) + num_allocated_blocks = min(num_required_blocks, self.mempool.num_free_blocks) + return self.mempool.allocate_blocks(num_allocated_blocks) + + def recycle(self, physical_blocks: np.ndarray) -> None: + self.mempool.recycle_blocks(physical_blocks) class CacheEngine: def __init__(self, device_type: DeviceType, num_total_blocks: int, - 
tokens_per_block: int): + tokens_per_block: int, + evict_ratio: float): if not isinstance(device_type, DeviceType): raise InvalidConfigError(f"Unknown device type: {device_type}") if num_total_blocks <= 0: @@ -55,6 +162,7 @@ def __init__(self, self.tokens_per_block = tokens_per_block self.num_total_blocks = num_total_blocks + self.evict_ratio = evict_ratio def reset(self) -> None: self.index.reset() @@ -67,7 +175,7 @@ def match(self, sequence_meta: SequenceMeta) -> MatchResult: def insert(self, sequence_meta: SequenceMeta, - physical_block_ids: torch.Tensor, + physical_block_ids: np.ndarray, num_insert_blocks: int = -1, is_ready: bool = True, match_result: Optional[MatchResult] = None) -> Optional[RadixNode]: @@ -87,12 +195,13 @@ def cleanup(self, node: RadixNode, cleanup_length: int) -> None: def take(self, num_required_blocks: int, protected_node: Optional[RadixNode] = None, - strict: bool = True) -> torch.Tensor: + strict: bool = True) -> np.ndarray: if num_required_blocks > self.mempool.num_free_blocks: if protected_node is not None: self.index.lock(protected_node) + evict_block_num = max(num_required_blocks - self.mempool.num_free_blocks, int(self.mempool.num_total_blocks * self.evict_ratio)) self.mempool.recycle_blocks( - self.index.evict(num_required_blocks - self.mempool.num_free_blocks) + self.index.evict(evict_block_num) ) if protected_node is not None: self.index.unlock(protected_node) @@ -103,7 +212,7 @@ def take(self, num_allocated_blocks = min(num_required_blocks, self.mempool.num_free_blocks) return self.mempool.allocate_blocks(num_allocated_blocks) - def recycle(self, physical_blocks: torch.Tensor) -> None: + def recycle(self, physical_blocks: np.ndarray) -> None: self.mempool.recycle_blocks(physical_blocks) class GlobalCacheEngine: @@ -119,19 +228,40 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig): self.cache_engines = {} if cache_config.enable_cpu: - self.cpu_cache_engine = CacheEngine(DeviceType.CPU, + if cache_config.index_accel: + self.cpu_cache_engine = CacheEngineAccel(DeviceType.CPU, cache_config.num_cpu_blocks, - cache_config.tokens_per_block) + cache_config.tokens_per_block, + cache_config.evict_ratio) + else: + self.cpu_cache_engine = CacheEngine(DeviceType.CPU, + cache_config.num_cpu_blocks, + cache_config.tokens_per_block, + cache_config.evict_ratio) self.cache_engines[DeviceType.CPU] = self.cpu_cache_engine if cache_config.enable_ssd: - self.ssd_cache_engine = CacheEngine(DeviceType.SSD, + if cache_config.index_accel: + self.ssd_cache_engine = CacheEngineAccel(DeviceType.SSD, + cache_config.num_ssd_blocks, + cache_config.tokens_per_block, + cache_config.evict_ratio) + else: + self.ssd_cache_engine = CacheEngine(DeviceType.SSD, cache_config.num_ssd_blocks, - cache_config.tokens_per_block) + cache_config.tokens_per_block, + cache_config.evict_ratio) self.cache_engines[DeviceType.SSD] = self.ssd_cache_engine if cache_config.enable_remote: - self.remote_cache_engine = CacheEngine(DeviceType.REMOTE, + if cache_config.index_accel: + self.remote_cache_engine = CacheEngineAccel(DeviceType.REMOTE, + cache_config.num_remote_blocks, + cache_config.tokens_per_block, + cache_config.evict_ratio) + else: + self.remote_cache_engine = CacheEngine(DeviceType.REMOTE, cache_config.num_remote_blocks, - cache_config.tokens_per_block) + cache_config.tokens_per_block, + cache_config.evict_ratio) self.cache_engines[DeviceType.REMOTE] = self.remote_cache_engine self._empty_get_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, int]] = \ @@ 
-149,25 +279,34 @@ def reset(self) -> None: def get(self, request_id: int, - token_ids: torch.Tensor, - token_mask: torch.Tensor, - slot_mapping: torch.Tensor, + token_ids: np.ndarray, + token_mask: np.ndarray, + slot_mapping: np.ndarray, layer_num: int = -1, layer_granularity: int = -1, - dp_id: int = 0) -> Tuple[TransferOpGraph, torch.Tensor, Callable, List[int]]: + dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, int]: self._check_input(token_ids, token_mask, slot_mapping) + if layer_num == -1: layer_num = self.model_config.num_layers if layer_granularity == -1: layer_granularity = layer_num + if layer_num != layer_granularity: + flexkv_logger.error(f"Layerwise transfer is not supported yet, " + f"layer_num: {layer_num}, layer_granularity: {layer_granularity}") + raise NotImplementedError(f"Layerwise transfer is not supported yet, " + f"layer_num: {layer_num}, layer_granularity: {layer_granularity}") + # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] - token_mask[aligned_length:] = False + token_mask = token_mask[:aligned_length] + block_start_idx, block_end_idx = self._get_block_range(token_mask) assert block_end_idx == aligned_length // self.tokens_per_block - gpu_block_mapping = self._slot_to_block_mapping(slot_mapping)[:block_end_idx-block_start_idx] + gpu_block_ids = self.slot_mapping_to_block_ids(slot_mapping, + self.tokens_per_block)[:block_end_idx-block_start_idx] sequence_meta = SequenceMeta(token_ids=aligned_token_ids, tokens_per_block=self.cache_config.tokens_per_block) @@ -179,7 +318,7 @@ def get(self, sequence_meta, block_start_idx, block_end_idx, - gpu_block_mapping, + gpu_block_ids, layer_num ) else: @@ -189,24 +328,24 @@ def get(self, sequence_meta, block_start_idx, block_end_idx, - gpu_block_mapping, + gpu_block_ids, layer_num ) - transfer_graph, finished_ops_ids = add_virtal_op_for_mutiple_finished_ops( + transfer_graph, task_end_op_id = add_virtal_op_for_mutiple_finished_ops( transfer_graph, finished_ops_ids ) - return_mask = torch.zeros_like(token_mask) + return_mask = np.zeros_like(token_mask, dtype=np.bool_) return_mask[block_start_idx* self.tokens_per_block: (block_start_idx + num_gpu_blocks_to_transfer) * self.tokens_per_block] = True - if layer_num // layer_granularity != 1: - transfer_graph, finished_ops_ids = convert_read_graph_to_layer_wise_graph(transfer_graph=transfer_graph, - finished_ops_ids=finished_ops_ids, - layer_num=layer_num, - layer_granularity=layer_granularity) + # if layer_num // layer_granularity != 1: + # transfer_graph, finished_ops_ids = convert_read_graph_to_layer_wise_graph(transfer_graph=transfer_graph, + # finished_ops_ids=finished_ops_ids, + # layer_num=layer_num, + # layer_granularity=layer_granularity) transfer_graph.bind_to_dp_group(dp_id) for device_type in node_to_unlock: @@ -216,14 +355,14 @@ def get(self, node_to_unlock=node_to_unlock, buffer_to_free=buffer_to_free) - return transfer_graph, return_mask, callback, finished_ops_ids + return transfer_graph, return_mask, callback, task_end_op_id def _get_impl_global(self, request_id: int, sequence_meta: SequenceMeta, block_mask_start: int, block_mask_end: int, - gpu_block_mapping: torch.Tensor, + gpu_block_ids: np.ndarray, layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int]: """ transfer pattern: @@ -252,7 +391,7 @@ def _get_impl_global(self, #early return if no blocks to transfer if fragment123_num_blocks == 0: return 
self._empty_get_return(request_id) - assert fragment123_num_blocks <= len(gpu_block_mapping) + assert fragment123_num_blocks <= len(gpu_block_ids) transfer_graph = TransferOpGraph() finished_ops_ids = [] @@ -263,7 +402,7 @@ def _get_impl_global(self, fragment3_num_blocks = max(len(remote_matched_blocks) - fragment12_num_blocks, 0) fragment23_num_blocks = fragment2_num_blocks + fragment3_num_blocks - fragment123_gpu_blocks = gpu_block_mapping[:fragment123_num_blocks] + fragment123_gpu_blocks = gpu_block_ids[:fragment123_num_blocks] fragment123_cpu_blocks = cpu_matched_blocks fragment2_ssd_blocks = ssd_matched_blocks[-fragment2_num_blocks:] fragment3_remote_blocks = remote_matched_blocks[-fragment3_num_blocks:] @@ -271,7 +410,7 @@ def _get_impl_global(self, cpu_node_to_unlock = cpu_matched_result.last_ready_node ssd_node_to_unlock = ssd_matched_result.last_ready_node remote_node_to_unlock = remote_matched_result.last_ready_node - cpu_blocks_to_free = torch.tensor([], dtype=torch.int64) + cpu_blocks_to_free = np.array([], dtype=np.int64) if fragment23_num_blocks > 0: num_extra_required_blocks = fragment23_num_blocks @@ -283,7 +422,7 @@ def _get_impl_global(self, if len(fragment23_cpu_blocks) < num_extra_required_blocks: self.cpu_cache_engine.recycle(fragment23_cpu_blocks) return self._empty_get_return(request_id) - fragment123_cpu_blocks = torch.cat([fragment123_cpu_blocks, fragment23_cpu_blocks]) + fragment123_cpu_blocks = np.concatenate([fragment123_cpu_blocks, fragment23_cpu_blocks]) # we only insert the buffer blocks to cpu cache engine only: # 1. the cpu cache engine satisfies prefix cache after insertion # 2. the sequence is all ready blocks @@ -302,14 +441,8 @@ def _get_impl_global(self, op_disk2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.DISK2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment2_ssd_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment123_cpu_blocks[fragment1_num_blocks:fragment12_num_blocks] - ), + src_block_ids = fragment2_ssd_blocks, + dst_block_ids = fragment123_cpu_blocks[fragment1_num_blocks:fragment12_num_blocks], layer_id = 0, layer_granularity = layer_num ) @@ -320,14 +453,8 @@ def _get_impl_global(self, op_remote2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.REMOTE2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.REMOTE, - physical_block_ids=fragment3_remote_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment123_cpu_blocks[-fragment3_num_blocks:], - ), + src_block_ids = fragment3_remote_blocks, + dst_block_ids = fragment123_cpu_blocks[-fragment3_num_blocks:], layer_id = 0, layer_granularity = layer_num ) @@ -353,14 +480,8 @@ def _get_impl_global(self, op_h2disk = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment123_cpu_blocks[-fragment3_num_blocks:], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment3_ssd_blocks, - ), + src_block_ids = fragment123_cpu_blocks[-fragment3_num_blocks:], + dst_block_ids = fragment3_ssd_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -377,15 +498,8 @@ def _get_impl_global(self, op_h2d = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2D, - 
src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment123_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment123_gpu_blocks, - device_id = 0 - ), + src_block_ids = fragment123_cpu_blocks, + dst_block_ids = fragment123_gpu_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -414,7 +528,7 @@ def _get_impl_local(self, sequence_meta: SequenceMeta, block_mask_start: int, block_mask_end: int, - gpu_block_mapping: torch.Tensor, + gpu_block_ids: np.ndarray, layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int]: """ transfer pattern: @@ -429,7 +543,10 @@ def _get_impl_local(self, assert self.cache_config.enable_cpu assert self.cpu_cache_engine is not None - cpu_matched_result, ssd_matched_result = self.match_local(sequence_meta) + if self.cache_config.index_accel: + cpu_matched_result, ssd_matched_result = self.match_local_accel(sequence_meta) + else: + cpu_matched_result, ssd_matched_result = self.match_local(sequence_meta) # tailor the blocks to assure: # the blocks are needed by the mask & the blocks are ready @@ -445,12 +562,12 @@ def _get_impl_local(self, #early return if no blocks to transfer if fragment12_num_blocks == 0: return self._empty_get_return(request_id) - assert fragment12_num_blocks <= len(gpu_block_mapping) + assert fragment12_num_blocks <= len(gpu_block_ids) transfer_graph = TransferOpGraph() finished_ops_ids = [] - fragment12_gpu_blocks = gpu_block_mapping[:fragment12_num_blocks] + fragment12_gpu_blocks = gpu_block_ids[:fragment12_num_blocks] fragment2_ssd_blocks = ssd_matched_blocks[-fragment2_num_blocks:] fragment1_cpu_blocks = cpu_matched_blocks[:fragment1_num_blocks] @@ -458,7 +575,9 @@ def _get_impl_local(self, ssd_node_to_unlock = ssd_matched_result.last_ready_node # prepare cpu blocks to transfer - cpu_blocks_to_free = torch.tensor([], dtype=torch.int64) + cpu_blocks_to_free = np.array([], dtype=np.int64) + op_disk2h = None + fragment2_cpu_blocks = None if fragment2_num_blocks > 0: fragment2_cpu_blocks = self.cpu_cache_engine.take( num_required_blocks=fragment2_num_blocks, @@ -474,39 +593,12 @@ def _get_impl_local(self, op_disk2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.DISK2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment2_ssd_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment2_cpu_blocks, - ), + src_block_ids = fragment2_ssd_blocks, + dst_block_ids = fragment2_cpu_blocks, layer_id = 0, layer_granularity = layer_num ) transfer_graph.add_transfer_op(op_disk2h) - - op_h2d_frag2 = TransferOp( - graph_id = transfer_graph.graph_id, - transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment2_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment12_gpu_blocks[-fragment2_num_blocks:], - device_id = 0 - ), - layer_id = 0, - layer_granularity = layer_num - ) - transfer_graph.add_transfer_op(op_h2d_frag2) - - transfer_graph.add_dependency(op_h2d_frag2.op_id, op_disk2h.op_id) - finished_ops_ids.append(op_h2d_frag2.op_id) - # we only insert the buffer blocks to cpu cache engine only: # 1. the cpu cache engine satisfies prefix cache after insertion # 2. 
the sequence is all ready blocks @@ -520,23 +612,22 @@ def _get_impl_local(self, match_result=cpu_matched_result) else: cpu_blocks_to_free = fragment2_cpu_blocks - op_h2d_frag1 = TransferOp( + if fragment2_cpu_blocks is not None: + fragment12_cpu_blocks = np.concatenate([fragment1_cpu_blocks, fragment2_cpu_blocks]) + else: + fragment12_cpu_blocks = fragment1_cpu_blocks + op_h2d = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment1_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment12_gpu_blocks[:fragment1_num_blocks], - device_id = 0 - ), + src_block_ids = fragment12_cpu_blocks, + dst_block_ids = fragment12_gpu_blocks, layer_id = 0, layer_granularity = layer_num ) - transfer_graph.add_transfer_op(op_h2d_frag1) - finished_ops_ids.append(op_h2d_frag1.op_id) + transfer_graph.add_transfer_op(op_h2d) + if op_disk2h is not None: + transfer_graph.add_dependency(op_h2d.op_id, op_disk2h.op_id) + finished_ops_ids.append(op_h2d.op_id) node_to_unlock = {} if cpu_node_to_unlock is not None: @@ -549,11 +640,11 @@ def _get_impl_local(self, def put(self, request_id: int, - token_ids: torch.Tensor, - token_mask: torch.Tensor, - slot_mapping: torch.Tensor, + token_ids: np.ndarray, + token_mask: np.ndarray, + slot_mapping: np.ndarray, layer_num : int = -1, - dp_id: int = 0) -> Tuple[TransferOpGraph, torch.Tensor, Callable, List[int]]: + dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, int]: self._check_input(token_ids, token_mask, slot_mapping) if layer_num == -1: @@ -561,13 +652,14 @@ def put(self, # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] - token_mask[aligned_length:] = False + token_mask = token_mask[:aligned_length] block_start_idx, block_end_idx = self._get_block_range(token_mask) # the mask should has a prefix of True assert block_start_idx == 0 - gpu_block_mapping = self._slot_to_block_mapping(slot_mapping)[:block_end_idx-block_start_idx] + gpu_block_ids = self.slot_mapping_to_block_ids(slot_mapping, + self.tokens_per_block)[:block_end_idx-block_start_idx] sequence_meta = SequenceMeta(token_ids=aligned_token_ids, tokens_per_block=self.cache_config.tokens_per_block) @@ -580,7 +672,7 @@ def put(self, sequence_meta, block_start_idx, block_end_idx, - gpu_block_mapping, + gpu_block_ids, layer_num ) else: @@ -591,16 +683,16 @@ def put(self, sequence_meta, block_start_idx, block_end_idx, - gpu_block_mapping, + gpu_block_ids, layer_num ) - transfer_graph, finished_ops_ids = add_virtal_op_for_mutiple_finished_ops( + transfer_graph, task_end_op_id = add_virtal_op_for_mutiple_finished_ops( transfer_graph, finished_ops_ids ) - return_mask = torch.zeros_like(token_mask) + return_mask = np.zeros_like(token_mask, dtype=np.bool_) return_mask[(block_start_idx + skipped_gpu_blocks)* self.tokens_per_block: (block_start_idx + skipped_gpu_blocks + num_gpu_blocks_to_transfer) * self.tokens_per_block] = True transfer_graph.bind_to_dp_group(dp_id) @@ -612,14 +704,14 @@ def put(self, node_to_unlock=node_to_unlock, buffer_to_free=buffer_to_free) - return transfer_graph, return_mask, callback, finished_ops_ids + return transfer_graph, return_mask, callback, task_end_op_id def _put_impl_global(self, request_id: int, sequence_meta: SequenceMeta, block_mask_start: int, block_mask_end: int, - 
gpu_block_mapping: torch.Tensor, + gpu_block_ids: np.ndarray, layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]: """ transfer pattern: @@ -639,7 +731,10 @@ def _put_impl_global(self, assert self.cpu_cache_engine is not None assert self.remote_cache_engine is not None - cpu_matched_result, ssd_matched_result, remote_matched_result = self.match_all(sequence_meta) + if self.cache_config.index_accel: + cpu_matched_result, ssd_matched_result, remote_matched_result = self.match_all_accel(sequence_meta) + else: + cpu_matched_result, ssd_matched_result, remote_matched_result = self.match_all(sequence_meta) cpu_matched_blocks = cpu_matched_result.physical_blocks[ :cpu_matched_result.num_matched_blocks][block_mask_start:block_mask_end] ssd_matched_blocks = ssd_matched_result.physical_blocks[ @@ -648,15 +743,15 @@ def _put_impl_global(self, :remote_matched_result.num_matched_blocks][block_mask_start:block_mask_end] num_skipped_blocks = len(cpu_matched_blocks) - fragment12_num_blocks = len(gpu_block_mapping) - num_skipped_blocks + fragment12_num_blocks = len(gpu_block_ids) - num_skipped_blocks if fragment12_num_blocks == 0: return self._empty_put_return(request_id) - fragment2_num_blocks = len(gpu_block_mapping) - len(ssd_matched_blocks) + fragment2_num_blocks = len(gpu_block_ids) - len(ssd_matched_blocks) if not self.cache_config.enable_ssd: fragment2_num_blocks = 0 - fragment3_num_blocks = len(gpu_block_mapping) - len(remote_matched_blocks) + fragment3_num_blocks = len(gpu_block_ids) - len(remote_matched_blocks) - fragment12_gpu_blocks = gpu_block_mapping[num_skipped_blocks:] + fragment12_gpu_blocks = gpu_block_ids[num_skipped_blocks:] fragment12_cpu_blocks = self.cpu_cache_engine.take( num_required_blocks=fragment12_num_blocks, @@ -678,7 +773,7 @@ def _put_impl_global(self, else: self.ssd_cache_engine.recycle(fragment2_ssd_blocks) else: - fragment2_ssd_blocks = torch.tensor([], dtype=torch.int64) + fragment2_ssd_blocks = np.array([], dtype=np.int64) put_to_remote = False if fragment3_num_blocks > 0: fragment3_remote_blocks = self.remote_cache_engine.take( @@ -691,7 +786,7 @@ def _put_impl_global(self, else: self.remote_cache_engine.recycle(fragment3_remote_blocks) else: - fragment3_remote_blocks = torch.tensor([], dtype=torch.int64) + fragment3_remote_blocks = np.array([], dtype=np.int64) transfer_graph = TransferOpGraph() finished_ops_ids = [] @@ -699,14 +794,8 @@ def _put_impl_global(self, op_d2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.D2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment12_gpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment12_cpu_blocks, - ), + src_block_ids = fragment12_gpu_blocks, + dst_block_ids = fragment12_cpu_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -718,14 +807,8 @@ def _put_impl_global(self, op_h2disk = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment2_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment2_ssd_blocks, - ), + src_block_ids = fragment2_cpu_blocks, + dst_block_ids = fragment2_ssd_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -736,21 +819,15 @@ def _put_impl_global(self, if put_to_remote: if fragment3_num_blocks > fragment12_num_blocks: 
extra_num_cpu_blocks = fragment3_num_blocks - fragment12_num_blocks - fragment3_cpu_blocks = torch.cat([fragment12_cpu_blocks, + fragment3_cpu_blocks = np.concatenate([fragment12_cpu_blocks, cpu_matched_blocks[-extra_num_cpu_blocks:]]) else: fragment3_cpu_blocks = fragment12_cpu_blocks[-fragment3_num_blocks:] op_h2remote = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2REMOTE, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment3_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.REMOTE, - physical_block_ids=fragment3_remote_blocks, - ), + src_block_ids = fragment3_cpu_blocks, + dst_block_ids = fragment3_remote_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -789,7 +866,7 @@ def _put_impl_local(self, sequence_meta: SequenceMeta, block_mask_start: int, block_mask_end: int, - gpu_block_mapping: torch.Tensor, + gpu_block_ids: np.ndarray, layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]: """ transfer pattern: @@ -805,21 +882,24 @@ def _put_impl_local(self, assert self.cpu_cache_engine is not None # assert self.ssd_cache_engine is not None - cpu_matched_result, ssd_matched_result = self.match_local(sequence_meta) + if self.cache_config.index_accel: + cpu_matched_result, ssd_matched_result = self.match_local_accel(sequence_meta) + else: + cpu_matched_result, ssd_matched_result = self.match_local(sequence_meta) cpu_matched_blocks = cpu_matched_result.physical_blocks[ :cpu_matched_result.num_matched_blocks][block_mask_start:block_mask_end] ssd_matched_blocks = ssd_matched_result.physical_blocks[ :ssd_matched_result.num_matched_blocks][block_mask_start:block_mask_end] num_skipped_blocks = len(cpu_matched_blocks) - fragment12_num_blocks = len(gpu_block_mapping) - num_skipped_blocks + fragment12_num_blocks = len(gpu_block_ids) - num_skipped_blocks if fragment12_num_blocks == 0: return self._empty_put_return(request_id) - fragment2_num_blocks = len(gpu_block_mapping) - len(ssd_matched_blocks) + fragment2_num_blocks = len(gpu_block_ids) - len(ssd_matched_blocks) if not self.cache_config.enable_ssd: fragment2_num_blocks = 0 - fragment12_gpu_blocks = gpu_block_mapping[num_skipped_blocks:] + fragment12_gpu_blocks = gpu_block_ids[num_skipped_blocks:] fragment12_cpu_blocks = self.cpu_cache_engine.take( num_required_blocks=fragment12_num_blocks, @@ -833,7 +913,7 @@ def _put_impl_local(self, strict=False ) else: - fragment2_ssd_blocks = torch.tensor([], dtype=torch.int64) + fragment2_ssd_blocks = np.array([], dtype=np.int64) if len(fragment12_cpu_blocks) < fragment12_num_blocks or \ len(fragment2_ssd_blocks) < fragment2_num_blocks: self.cpu_cache_engine.recycle(fragment12_cpu_blocks) @@ -847,14 +927,8 @@ def _put_impl_local(self, op_d2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.D2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment12_gpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment12_cpu_blocks, - ), + src_block_ids = fragment12_gpu_blocks, + dst_block_ids = fragment12_cpu_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -866,14 +940,8 @@ def _put_impl_local(self, op_h2disk = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment2_cpu_blocks, - ), - dst_descriptor = 
TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment2_ssd_blocks, - ), + src_block_ids = fragment2_cpu_blocks, + dst_block_ids = fragment2_ssd_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -903,7 +971,7 @@ def _put_impl_local(self, def _transfer_callback(self, node_to_unlock: Dict[DeviceType, Tuple[RadixNode, int]], - buffer_to_free: Optional[Dict[DeviceType, torch.Tensor]] = None) -> None: + buffer_to_free: Optional[Dict[DeviceType, np.ndarray]] = None) -> None: if DeviceType.CPU in node_to_unlock: assert self.cpu_cache_engine is not None self.cpu_cache_engine.cleanup(node_to_unlock[DeviceType.CPU][0], node_to_unlock[DeviceType.CPU][1]) @@ -924,6 +992,18 @@ def _transfer_callback(self, assert self.remote_cache_engine is not None self.remote_cache_engine.recycle(buffer_to_free[DeviceType.REMOTE]) + @nvtx.annotate("Match Prefix Accel", color="yellow") + def match_local_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel]: + cpu_matched_result = MatchResultAccel() + ssd_matched_result = MatchResultAccel() + if self.cpu_cache_engine: + cpu_matched_result = self.cpu_cache_engine.match(sequence_meta) + if self.ssd_cache_engine: + ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) + + return cpu_matched_result, ssd_matched_result + + @nvtx.annotate("Match Prefix", color="yellow") def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult]: cpu_matched_result = MatchResult() ssd_matched_result = MatchResult() @@ -933,7 +1013,24 @@ def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchRe ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) return cpu_matched_result, ssd_matched_result + + @nvtx.annotate("Match All Prefix accel", color="yellow") + def match_all_accel(self, + sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel, MatchResultAccel]: + cpu_matched_result = MatchResultAccel() + ssd_matched_result = MatchResultAccel() + remote_matched_result = MatchResultAccel() + # TODO: avoid redundant match? 
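The accelerated match path mirrors the plain one: both return the same per-tier match results, and the only switch is the new `index_accel` flag on the cache config. A minimal sketch of enabling it (field names come from `CacheConfig` elsewhere in this patch; the concrete values are illustrative only):

```python
from flexkv.common.config import CacheConfig

# illustrative values only: with index_accel=True the get/put paths call
# match_local_accel()/match_all_accel(), backed by the C++ radix tree
cfg = CacheConfig(tokens_per_block=16, enable_ssd=True, index_accel=True)
assert cfg.index_accel
```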
+ if self.cpu_cache_engine: + cpu_matched_result = self.cpu_cache_engine.match(sequence_meta) + if self.ssd_cache_engine: + ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) + if self.remote_cache_engine: + remote_matched_result = self.remote_cache_engine.match(sequence_meta) + return cpu_matched_result, ssd_matched_result, remote_matched_result + + @nvtx.annotate("Match All Prefix", color="yellow") def match_all(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult, MatchResult]: cpu_matched_result = MatchResult() ssd_matched_result = MatchResult() @@ -949,25 +1046,29 @@ def match_all(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResu return cpu_matched_result, ssd_matched_result, remote_matched_result def _check_input(self, - token_ids: torch.Tensor, - token_mask: torch.Tensor, - slot_mapping: torch.Tensor) -> None: + token_ids: np.ndarray, + token_mask: np.ndarray, + slot_mapping: np.ndarray) -> None: + assert token_ids.dtype == np.int64 + # assert token_mask.dtype == np.bool_, f"token_mask.dtype={token_mask.dtype}" + assert slot_mapping.dtype == np.int64 assert token_ids.ndim == 1 assert token_mask.ndim == 1 assert slot_mapping.ndim == 1 - assert len(token_ids) == len(token_mask), f"len(token_ids)={len(token_ids)}, len(token_mask)={len(token_mask)}" - assert len(slot_mapping) == token_mask.sum().item(), f"len(slot_mapping)={len(slot_mapping)}, token_mask.sum().item()={token_mask.sum().item()}" + assert token_ids.size == token_mask.size, f"token_ids.size={token_ids.size}, token_mask.size={token_mask.size}" + assert slot_mapping.size == token_mask.sum(), \ + f"slot_mapping.size={slot_mapping.size}, token_mask.sum()={token_mask.sum()}" - def _slot_to_block_mapping(self, - slot_mapping: torch.Tensor) -> torch.Tensor: - block_mapping: torch.Tensor = slot_mapping[::self.tokens_per_block] // self.tokens_per_block - return block_mapping + @staticmethod + def slot_mapping_to_block_ids(slot_mapping: np.ndarray, tokens_per_block: int) -> np.ndarray: + block_ids: np.ndarray = slot_mapping[::tokens_per_block] // tokens_per_block + return block_ids def _get_block_range(self, - token_mask: torch.Tensor) -> Tuple[int, int]: - mask_idx = torch.where(token_mask)[0] + token_mask: np.ndarray) -> Tuple[int, int]: + mask_idx = np.where(token_mask)[0] if len(mask_idx) == 0: - return 0, 0 - start_idx = int(mask_idx[0].item() // self.tokens_per_block) - end_idx = int(mask_idx[-1].item() // self.tokens_per_block) + return len(token_mask)//self.tokens_per_block, len(token_mask)//self.tokens_per_block + start_idx = mask_idx[0].item() // self.tokens_per_block + end_idx = mask_idx[-1].item() // self.tokens_per_block return start_idx, end_idx + 1 diff --git a/flexkv/cache/mempool.py b/flexkv/cache/mempool.py index d0e6848a30..d00077181f 100644 --- a/flexkv/cache/mempool.py +++ b/flexkv/cache/mempool.py @@ -1,7 +1,7 @@ from collections import deque from typing import List -import torch +import numpy as np from flexkv.common.exceptions import NotEnoughSpaceError @@ -14,18 +14,18 @@ def __init__( assert num_total_blocks > 0 self.num_total_blocks = num_total_blocks - self._free_mask = torch.ones(self.num_total_blocks, dtype=torch.bool) + self._free_mask = np.ones(self.num_total_blocks, dtype=np.bool_) self._num_free = num_total_blocks - self._free_ids = self._free_mask.nonzero().squeeze(1) + self._free_ids = self._free_mask.nonzero()[0] self._free_ids_offset = 0 def reset(self) -> None: - self._free_mask.fill_(True) + self._free_mask.fill(True) self._num_free = 
self.num_total_blocks - self._free_ids = self._free_mask.nonzero().squeeze(1) + self._free_ids = self._free_mask.nonzero()[0] self._free_ids_offset = 0 - def allocate_blocks(self, num: int) -> torch.Tensor: + def allocate_blocks(self, num: int) -> np.ndarray: if num < 0: raise ValueError(f"num must be greater than 0, but got {num}") if num > self._num_free: @@ -41,8 +41,8 @@ def allocate_blocks(self, num: int) -> torch.Tensor: self._num_free -= num return free_ids - def recycle_blocks(self, block_ids: torch.Tensor) -> None: - if block_ids.ndim != 1 or block_ids.dtype != torch.int64: + def recycle_blocks(self, block_ids: np.ndarray) -> None: + if block_ids.ndim != 1 or block_ids.dtype != np.int64: raise ValueError("block_ids must be a 1D tensor of int64") if self._free_mask[block_ids].any(): free_ids = block_ids[self._free_mask[block_ids]] @@ -51,7 +51,7 @@ def recycle_blocks(self, block_ids: torch.Tensor) -> None: self._num_free += len(block_ids) def _update_free_ids(self) -> None: - self._free_ids = self._free_mask.nonzero().squeeze(1) + self._free_ids = self._free_mask.nonzero()[0] self._free_ids_offset = 0 @property diff --git a/flexkv/cache/radixtree.py b/flexkv/cache/radixtree.py index 818e778a63..a9f69c6a35 100644 --- a/flexkv/cache/radixtree.py +++ b/flexkv/cache/radixtree.py @@ -32,11 +32,14 @@ class MatchResult: last_ready_node: Optional['RadixNode'] = None last_node: Optional['RadixNode'] = None last_node_matched_length: int = 0 - physical_blocks: torch.Tensor = torch.empty(0, dtype=torch.int64) + physical_blocks: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) def __post_init__(self) -> None: assert self.physical_blocks.ndim == 1 - assert self.physical_blocks.dtype == torch.int64 + assert self.physical_blocks.dtype == np.int64 + + def is_empty(self) -> bool: + return self.num_matched_blocks == 0 @dataclass class RadixNode: @@ -192,7 +195,7 @@ def match_prefix(self, last_ready_node=last_ready_node, last_node=current_node, last_node_matched_length=last_node_matched_length, - physical_blocks=torch.from_numpy(physical_blocks).to(torch.int64)) + physical_blocks=physical_blocks) def num_matched_blocks(self, sequence: SequenceMeta) -> int: @@ -201,7 +204,7 @@ def num_matched_blocks(self, def insert(self, sequence_meta: SequenceMeta, - physical_block_ids: torch.Tensor, + physical_block_ids: np.ndarray, num_insert_blocks: int = -1, is_ready: bool = True, match_result: Optional[MatchResult] = None) -> Optional[RadixNode]: @@ -210,7 +213,7 @@ def insert(self, assert 0 <= num_insert_blocks <= sequence_meta.num_blocks assert physical_block_ids.ndim == 1 - assert physical_block_ids.dtype == torch.int64 + assert physical_block_ids.dtype == np.int64 sequence_meta.gen_hashes() if match_result is None: @@ -232,7 +235,7 @@ def insert(self, new_node = RadixNode( block_hashes=sequence_meta.block_hashes[num_matched_blocks:num_insert_blocks], - physical_blocks=physical_block_ids.numpy(), + physical_blocks=physical_block_ids, is_ready=is_ready, lock_cnt=0, last_access_time=time.time() @@ -255,7 +258,7 @@ def insert(self, return new_node - def evict(self, num_evicted: int) -> torch.Tensor: + def evict(self, num_evicted: int) -> np.ndarray: candidates = [] for node in self.leaf_nodes.values(): if node.evictable(): @@ -277,7 +280,7 @@ def evict(self, num_evicted: int) -> torch.Tensor: physical_blocks = node.physical_blocks node.parent = None evicted_blocks = np.concatenate([evicted_blocks, physical_blocks]) - return torch.from_numpy(evicted_blocks).to(torch.int64) + return 
evicted_blocks def lock(self, node: RadixNode) -> None: if node.lock_cnt < 0: @@ -340,22 +343,22 @@ def total_unready_blocks(self) -> int: index = RadixTreeIndex(tokens_per_block=tokens_per_block) print(f"init index, tokens_per_block = {tokens_per_block}") - token_ids1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) - token_ids2 = torch.tensor([1, 2, 3, 4, 15, 16, 17, 18]) + token_ids1 = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int64) + token_ids2 = np.array([1, 2, 3, 4, 15, 16, 17, 18], dtype=np.int64) seq1 = SequenceMeta(token_ids=token_ids1, tokens_per_block=tokens_per_block) seq2 = SequenceMeta(token_ids=token_ids2, tokens_per_block=tokens_per_block) - index.insert(seq1, torch.tensor([0, 1, 2, 3], dtype=torch.int64), is_ready=True) + index.insert(seq1, np.array([0, 1, 2, 3], dtype=np.int64), is_ready=True) print(f"insert seq1 = {seq1.token_ids}, " f"total cached blocks = {index.total_cached_blocks()}") seq2_matched_blocks = index.num_matched_blocks(seq2) assert seq2_matched_blocks == 2 - index.insert(seq2, torch.tensor([8, 9], dtype=torch.int64), is_ready=True) + index.insert(seq2, np.array([8, 9], dtype=np.int64), is_ready=True) print(f"insert seq2 = {seq2.token_ids}, " f"total cached blocks = {index.total_cached_blocks()}") - seq3 = SequenceMeta(token_ids=torch.tensor([1,2,3,4,0,0]), + seq3 = SequenceMeta(token_ids=np.array([1,2,3,4,0,0], dtype=np.int64), tokens_per_block=tokens_per_block) match_result = index.num_matched_blocks(seq3) print(f"match {seq3.token_ids}, num cached blocks: {match_result}") diff --git a/flexkv/cache/transfer_pattern.py b/flexkv/cache/transfer_pattern.py index 1246d248e8..fcf207408b 100644 --- a/flexkv/cache/transfer_pattern.py +++ b/flexkv/cache/transfer_pattern.py @@ -1,371 +1,33 @@ from typing import List, Optional, Tuple +import numpy as np import torch -from flexkv.common.transfer import DeviceType, TransferType -from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferDescriptor +from flexkv.common.transfer import TransferType +from flexkv.common.transfer import TransferOp, TransferOpGraph def add_virtal_op_for_mutiple_finished_ops( graph: TransferOpGraph, finished_ops_ids: List[int] -)->Tuple[TransferOpGraph, List[int]]: - if len(finished_ops_ids) <= 1: - return graph, finished_ops_ids +)->Tuple[TransferOpGraph, int]: + if len(finished_ops_ids) == 0: + return graph, -1 + elif len(finished_ops_ids) == 1: + return graph, finished_ops_ids[0] else: op = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.VIRTUAL, + src_block_ids = np.array([], dtype=np.int64), + dst_block_ids = np.array([], dtype=np.int64), layer_id = -1, layer_granularity = -1, ) graph.add_transfer_op(op) for op_id in finished_ops_ids: graph.add_dependency(op.op_id, op_id) - return graph, [op.op_id] - -def create_read_graph_cpu_storage( - gpu_blocks: torch.Tensor, - cpu_blocks: torch.Tensor, - ssd_blocks: torch.Tensor, - gpu_device_id: int = 0, - layer_num: int = 1, - graph: Optional[TransferOpGraph] = None, -)->Tuple[TransferOpGraph, List[int]]: - """ - Create a read transfer graph with (REMOTE_STORAGE / SSD)->CPU->GPU operations - ssd_blocks: the blocks of ssd that are used as a lower-level storage backend, - including ssd or remote storage. This can be empty, which means cpu-only kvcache. 
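With the new single-op-id convention, callers track one task-end op instead of a list of finished op ids. A minimal sketch (block ids are illustrative) of how `add_virtal_op_for_mutiple_finished_ops` behaves in the one-op case:

```python
import numpy as np
from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferType
from flexkv.cache.transfer_pattern import add_virtal_op_for_mutiple_finished_ops

graph = TransferOpGraph()
op = TransferOp(graph_id=graph.graph_id,
                transfer_type=TransferType.H2D,
                src_block_ids=np.array([0, 1], dtype=np.int64),
                dst_block_ids=np.array([4, 5], dtype=np.int64))
graph.add_transfer_op(op)

# zero finished ops -> -1; exactly one -> its own id; several -> a new VIRTUAL op
graph, task_end_op_id = add_virtal_op_for_mutiple_finished_ops(graph, [op.op_id])
assert task_end_op_id == op.op_id
```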
- Returns: - graph: TransferOpGraph - ops_to_be_tracked: List[int]: a list of transfer ops that can indicate - the completion of some key operations - """ - assert len(gpu_blocks) == len(cpu_blocks) - if graph is None: - graph = TransferOpGraph() - assert len(gpu_blocks) > 0 - if len(ssd_blocks) == 0: - op = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks, - device_id = gpu_device_id - ), - layer_id = 0, - layer_granularity = layer_num, - ) - graph.add_transfer_op(op) - return graph, [op.op_id] - elif len(ssd_blocks) < len(cpu_blocks): - task_end_ops_ids = [] - if len(ssd_blocks) > 0: - op1 = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.DISK2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=ssd_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(ssd_blocks):] - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op1) - op2 = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(ssd_blocks):] - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks[-len(ssd_blocks):], - device_id = gpu_device_id - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op2) - graph.add_dependency(op2.op_id, op1.op_id) - task_end_ops_ids.append(op2.op_id) - op3 = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[:len(cpu_blocks) - len(ssd_blocks)] - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks[:len(cpu_blocks) - len(ssd_blocks)], - device_id = gpu_device_id - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op3) - task_end_ops_ids.append(op3.op_id) - return graph, task_end_ops_ids - else: - op1 = TransferOp( - graph_id = graph.graph_id, - transfer_type=TransferType.DISK2H, - src_descriptor=TransferDescriptor( - device_type=DeviceType.SSD, - physical_block_ids=ssd_blocks, - ), - dst_descriptor=TransferDescriptor( - device_type=DeviceType.CPU, - physical_block_ids=cpu_blocks, - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op1) - op2 = TransferOp( - graph_id = graph.graph_id, - transfer_type=TransferType.H2D, - src_descriptor=TransferDescriptor( - device_type=DeviceType.CPU, - physical_block_ids=cpu_blocks, - ), - dst_descriptor=TransferDescriptor( - device_type=DeviceType.GPU, - physical_block_ids=gpu_blocks, - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op2) - graph.add_dependency(op2.op_id, op1.op_id) - return graph, [op2.op_id] - -def create_read_graph_cpu_ssd_remote( - gpu_blocks: torch.Tensor, - cpu_blocks: torch.Tensor, - ssd_blocks: torch.Tensor, - remote_blocks: torch.Tensor, - gpu_device_id: int = 0, - layer_num: int = 1, - write_back_to_ssd: bool = True, -)->Tuple[TransferOpGraph, List[int]]: - """ - Create a read transfer graph with (REMOTE_STORAGE + SSD)->CPU->GPU operations - Returns: - graph: TransferOpGraph - 
finished_ops_ids: List[int]: a list of transfer ops that can indicate - the completion of each layer or each layer for each tp rank - """ - graph = TransferOpGraph() - finished_ops_ids: List[int] = [] - if len(remote_blocks) == 0: - graph, finished_ops_ids = create_read_graph_cpu_storage(gpu_blocks=gpu_blocks, - cpu_blocks=cpu_blocks, - ssd_blocks=ssd_blocks, - gpu_device_id=gpu_device_id, - layer_num=layer_num, - graph=graph) - if len(finished_ops_ids) > 0: - graph, finished_ops_ids = add_virtal_op_for_mutiple_finished_ops(graph, finished_ops_ids) - assert len(finished_ops_ids) > 0 - return graph, finished_ops_ids - else: - if len(remote_blocks) < len(gpu_blocks): - graph, finished_ops_ids = create_read_graph_cpu_storage(gpu_blocks=gpu_blocks[:-len(remote_blocks)], - cpu_blocks=cpu_blocks[:-len(remote_blocks)], - ssd_blocks=ssd_blocks[:-len(remote_blocks)], - gpu_device_id=gpu_device_id, - layer_num=layer_num, - graph=graph) - op_r2h = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.REMOTE2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.REMOTE, - physical_block_ids=remote_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(remote_blocks):], - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_r2h) - op_h2d = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(remote_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks[-len(remote_blocks):], - device_id = gpu_device_id - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2d) - graph.add_dependency(op_h2d.op_id, op_r2h.op_id) - if write_back_to_ssd: - op_h2disk = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(remote_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=ssd_blocks[-len(remote_blocks):], - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2disk) - graph.add_dependency(op_h2disk.op_id, op_r2h.op_id) - finished_ops_ids.append(op_h2d.op_id) - if len(finished_ops_ids) > 0: - graph, finished_ops_ids = add_virtal_op_for_mutiple_finished_ops(graph, finished_ops_ids) - return graph, finished_ops_ids - -def create_write_graph_cpu_storage( - gpu_blocks: torch.Tensor, - cpu_blocks: torch.Tensor, - ssd_blocks: torch.Tensor, - gpu_device_id: int = 0, - layer_num: int = 1, - graph: Optional[TransferOpGraph] = None, -)->Tuple[TransferOpGraph, List[int]]: - """ - Create a write transfer graph with CPU->REMOTE_STORAGE / SSD operations - ssd_blocks: the blocks of ssd that are used as a lower-level storage backend, - including ssd or remote storage. This can be empty, which means cpu-only kvcache. - Write op granularity is larger: gpu->cpu is put into the same op. 
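These `create_read_graph_*`/`create_write_graph_*` helpers are removed in this patch; the cache engine now builds the same graphs inline from `TransferOp`s plus explicit dependencies. A minimal sketch of the equivalent two-stage read (SSD to CPU, then CPU to GPU), with illustrative block ids:

```python
import numpy as np
from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferType

graph = TransferOpGraph()
cpu_blocks = np.array([100, 101], dtype=np.int64)
op_disk2h = TransferOp(graph_id=graph.graph_id,
                       transfer_type=TransferType.DISK2H,
                       src_block_ids=np.array([40, 41], dtype=np.int64),
                       dst_block_ids=cpu_blocks)
op_h2d = TransferOp(graph_id=graph.graph_id,
                    transfer_type=TransferType.H2D,
                    src_block_ids=cpu_blocks,
                    dst_block_ids=np.array([0, 1], dtype=np.int64))
graph.add_transfer_op(op_disk2h)
graph.add_transfer_op(op_h2d)
graph.add_dependency(op_h2d.op_id, op_disk2h.op_id)  # the H2D copy waits for DISK2H
```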
- Returns: - graph: TransferOpGraph - layer_wise_ops: List[int]: a list of transfer ops that can indicate - the completion of each layer or each layer for each tp rank - """ - if graph is None: - graph = TransferOpGraph() - op_d2h = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.D2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks, - device_id = gpu_device_id - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(gpu_blocks):], - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_d2h) - if len(ssd_blocks) == 0: - return graph, [op_d2h.op_id] - else: - op_h2disk = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(ssd_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=ssd_blocks, - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2disk) - graph.add_dependency(op_h2disk.op_id, op_d2h.op_id) - return graph, [op_d2h.op_id] - -def create_write_graph_cpu_ssd_remote( - gpu_blocks: torch.Tensor, - cpu_blocks: torch.Tensor, - ssd_blocks: torch.Tensor, - remote_blocks: torch.Tensor, - gpu_device_id: int = 0, - layer_num: int = 1, -)->Tuple[TransferOpGraph, List[int]]: - """ - Create a write transfer graph with CPU->REMOTE_STORAGE + SSD operations - Returns: - graph: TransferOpGraph - layer_wise_ops: List[int]: a list of transfer ops that can indicate - the completion of each layer or each layer for each tp rank - """ - graph = TransferOpGraph() - op_d2h = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.D2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks, - device_id = gpu_device_id - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(gpu_blocks):], - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_d2h) - if len(ssd_blocks) != 0: - op_h2disk = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(gpu_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=ssd_blocks, - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2disk) - graph.add_dependency(op_h2disk.op_id, op_d2h.op_id) - if len(remote_blocks) != 0: - op_h2remote = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2REMOTE, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(remote_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.REMOTE, - physical_block_ids=remote_blocks, - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2remote) - graph.add_dependency(op_h2remote.op_id, op_d2h.op_id) - return graph, [op_d2h.op_id] + return graph, op.op_id def convert_read_graph_to_layer_wise_graph( transfer_graph: TransferOpGraph, @@ -394,8 +56,8 @@ def convert_read_graph_to_layer_wise_graph( new_op = TransferOp( graph_id=new_graph.graph_id, transfer_type=op.transfer_type, - src_descriptor=op.src_descriptor, - 
dst_descriptor=op.dst_descriptor, + src_block_ids=op.src_block_ids, + dst_block_ids=op.dst_block_ids, layer_id=i * layer_granularity, layer_granularity=layer_granularity, # Inherit these fields directly diff --git a/flexkv/common/block.py b/flexkv/common/block.py index d1ab24687a..7e2cc5d8b2 100644 --- a/flexkv/common/block.py +++ b/flexkv/common/block.py @@ -5,13 +5,13 @@ import numpy as np import torch -from flexkv.common.hash_utils import HashType, gen_hashes, get_hash_size, hash_tensor +from flexkv.common.hash_utils import HashType, gen_hashes, get_hash_size, hash_array @dataclass class SequenceMeta: - token_ids: torch.Tensor + token_ids: np.ndarray tokens_per_block: int @@ -19,7 +19,7 @@ class SequenceMeta: _has_hashes: bool = False - def __init__(self, token_ids: torch.Tensor, tokens_per_block: int): + def __init__(self, token_ids: np.ndarray, tokens_per_block: int): self.token_ids = token_ids self.tokens_per_block = tokens_per_block @@ -44,13 +44,13 @@ def get_hash(self, block_id: int) -> Optional[HashType]: if self._has_hashes: return HashType(int(self.block_hashes[block_id].item())) else: - return hash_tensor(self.token_ids[:(block_id+1)*self.tokens_per_block]) + return hash_array(self.token_ids[:(block_id+1)*self.tokens_per_block]) def gen_hashes(self) -> None: if self._has_hashes: return assert self.token_ids.ndim == 1 - self.block_hashes = gen_hashes(self.token_ids, self.tokens_per_block).numpy() + self.block_hashes = gen_hashes(self.token_ids, self.tokens_per_block) assert self.block_hashes.ndim == 1 assert self.block_hashes.size == self.num_blocks assert self.block_hashes.itemsize == get_hash_size() diff --git a/flexkv/common/config.py b/flexkv/common/config.py index a022292eb8..2805ae606c 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -14,6 +14,7 @@ class ModelConfig: head_size: int use_mla: bool = False dtype: torch.dtype = torch.bfloat16 + max_req_tokens = 163840 # parallel configs tp_size: int = 1 @@ -31,13 +32,13 @@ class CacheConfig: enable_ssd: bool = False enable_remote: bool = False use_gds: bool = False - use_pinned_memory: bool = False + index_accel: bool = False # kv cache layout configs gpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - cpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - ssd_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - remote_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE + cpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE + ssd_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE + remote_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE # mempool capacity configs num_cpu_blocks: int = 1000000 @@ -70,3 +71,17 @@ class CacheConfig: trace_max_file_size_mb: int = 100 trace_max_files: int = 5 trace_flush_interval_ms: int = 1000 + + #evict ratio + evict_ratio: float = 0.0 + + def __post_init__(self): + layout_fields = ['gpu_kv_layout_type', + 'cpu_kv_layout_type', + 'ssd_kv_layout_type', + 'remote_kv_layout_type'] + for field in layout_fields: + value = getattr(self, field) + if isinstance(value, str): + setattr(self, field, KVCacheLayoutType[value.upper()]) + diff --git a/flexkv/common/debug.py b/flexkv/common/debug.py index 6c52931db6..a522c5549a 100644 --- a/flexkv/common/debug.py +++ b/flexkv/common/debug.py @@ -6,18 +6,28 @@ from typing import Optional, Callable, Any +FLEXKV_LOGGING_PREFIX = os.getenv("FLEXKV_LOGGING_PREFIX", "FLEXKV") +_FORMAT = (f"[{FLEXKV_LOGGING_PREFIX}] %(levelname)s 
%(asctime)s.%(msecs)03d " + " %(message)s") +_DATE_FORMAT = "%m-%d %H:%M:%S" + class FlexkvLogger: def __init__(self, debug_level: str = "INFO"): self.enabled = False self.logger = logging.getLogger("FLEXKV") - formatter = logging.Formatter( - "%(asctime)s - %(levelname)s - %(message)s" + has_console_handler = any( + isinstance(handler, logging.StreamHandler) + for handler in self.logger.handlers ) - - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setFormatter(formatter) - self.logger.addHandler(console_handler) + if not has_console_handler: + formatter = logging.Formatter( + fmt=_FORMAT, + datefmt=_DATE_FORMAT, + ) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) self.set_level(debug_level) diff --git a/flexkv/common/hash_utils.py b/flexkv/common/hash_utils.py index 6f8ec9fc96..32f0055f48 100644 --- a/flexkv/common/hash_utils.py +++ b/flexkv/common/hash_utils.py @@ -1,6 +1,7 @@ import time from typing import NewType, Optional +import numpy as np import torch from flexkv import c_ext @@ -18,33 +19,37 @@ def __init__(self) -> None: def reset(self) -> None: self.hasher.reset() - def update(self, tensor: torch.Tensor) -> None: - self.hasher.update(tensor) + def update(self, array: np.ndarray) -> None: + self.hasher.update(torch.from_numpy(array)) def digest(self) -> HashType: return HashType(self.hasher.digest()) -def hash_tensor(tensor: torch.Tensor) -> HashType: - hasher = Hasher() - hasher.update(tensor) - return HashType(hasher.digest()) +_HASHER = Hasher() -def gen_hashes(token_ids: torch.Tensor, tokens_per_block: int, hasher: Optional[Hasher] = None) -> torch.Tensor: - block_hashes = torch.zeros(token_ids.numel() // tokens_per_block, dtype=torch.uint64) +def hash_array(array: np.ndarray) -> HashType: + _HASHER.reset() + _HASHER.update(array) + return HashType(_HASHER.digest()) + +def gen_hashes(token_ids: np.ndarray, tokens_per_block: int, hasher: Optional[Hasher] = None) -> np.ndarray: + block_hashes = np.zeros(token_ids.size // tokens_per_block, dtype=np.uint64) if hasher is None: hasher = Hasher() - c_ext.gen_hashes(hasher.hasher, token_ids, tokens_per_block, block_hashes) + c_ext.gen_hashes(hasher.hasher, torch.from_numpy(token_ids), tokens_per_block, torch.from_numpy(block_hashes)) return block_hashes if __name__ == "__main__": - torch.manual_seed(0) - token_ids = torch.randint(0, 10000, (32000, ), dtype=torch.int64) + np.random.seed(0) + token_ids = np.random.randint(0, 10000, (1000, ), dtype=np.int64) print(f"token ids length: {token_ids.shape[0]}") + result = hash_array(token_ids) start = time.time() - result = hash_tensor(token_ids) - end = time.time() - print(f"tensor hash: {result}, time: {end - start}s") - start = time.time() - result2 = gen_hashes(token_ids, 16) + for i in range(1): + result = hash_array(token_ids) end = time.time() - print(f"block hashes: {result2}, time: {end - start}s") + print(f"array hash: {result}, average time: {(end - start)*1000/5}ms") + # start = time.time() + # result2 = gen_hashes(token_ids, 16) + # end = time.time() + # print(f"block hashes: {result2}, time: {(end - start)*1000}ms") diff --git a/flexkv/common/memory_handle.py b/flexkv/common/memory_handle.py index d1e2838a87..5013308b40 100644 --- a/flexkv/common/memory_handle.py +++ b/flexkv/common/memory_handle.py @@ -3,7 +3,7 @@ import os import pickle import time -from typing import List, Callable, Any, Optional, Tuple, Union +from typing import Callable, Any, Optional, Tuple, 
Union from dataclasses import dataclass import torch @@ -16,7 +16,7 @@ @dataclass class TensorSharedHandle: rebuild_func: Callable - rebuild_args: List[Any] + rebuild_args: Tuple[Any] device: torch.device def __init__(self, tensor: torch.Tensor): @@ -29,7 +29,7 @@ def get_tensor(self) -> torch.Tensor: return tensor @staticmethod - def _export_tensor_handle(tensor: torch.Tensor) -> Tuple[Callable, List[Any], torch.device]: + def _export_tensor_handle(tensor: torch.Tensor) -> Tuple[Callable, Tuple[Any], torch.device]: device = tensor.device rebuild_func, rebuild_args = reductions.reduce_tensor(tensor) @@ -37,7 +37,7 @@ def _export_tensor_handle(tensor: torch.Tensor) -> Tuple[Callable, List[Any], to return rebuild_func, rebuild_args, device @staticmethod - def _import_tensor_handle(rebuild_func: Callable, rebuild_args: List[Any], device: torch.device) -> torch.Tensor: + def _import_tensor_handle(rebuild_func: Callable, rebuild_args: Tuple[Any], device: torch.device) -> torch.Tensor: try: tensor = rebuild_func(*rebuild_args) diff --git a/flexkv/common/request.py b/flexkv/common/request.py index 1c871ea821..ef1c6ca009 100644 --- a/flexkv/common/request.py +++ b/flexkv/common/request.py @@ -1,7 +1,9 @@ from dataclasses import dataclass from enum import Enum +from typing import Callable, List, Optional import torch +import numpy as np class KVRequestType(Enum): @@ -18,3 +20,21 @@ class KVRequest: slot_mapping: torch.Tensor layer_granularity: int = -1 dp_id: int = 0 + + +class KVResponseStatus(Enum): + SUCCESS = "success" + NOTFOUND = "not_found" + UNREADY = "unready" + TIMEOUT = "timeout" + CANCELLED = "cancelled" + FAILED = "failed" + +@dataclass +class KVResponse: + status: KVResponseStatus + task_id: int + return_mask: Optional[np.ndarray] + + def get_mask(self) -> torch.Tensor: + return torch.from_numpy(self.return_mask) if self.return_mask is not None else None diff --git a/flexkv/common/ring_buffer.py b/flexkv/common/ring_buffer.py new file mode 100644 index 0000000000..3f67d612ba --- /dev/null +++ b/flexkv/common/ring_buffer.py @@ -0,0 +1,109 @@ +import torch +import threading +import time +import random + +from collections import OrderedDict,deque +import numpy as np +from flexkv.common.transfer import TransferOp +from flexkv.common.debug import flexkv_logger +from flexkv.common.hash_utils import hash_array + + +class SharedOpPool: + def __init__(self, max_op_num: int, max_block_num: int, dtype = np.int64): + self.max_op_num = max_op_num + self.max_block_num = max_block_num + self.dtype = dtype + # create the buffer tensor + self.buffer_o = torch.empty((self.max_op_num, self.max_block_num), dtype = torch.int64) + # move tensor to share memory + self.buffer = self.buffer_o.share_memory_() + + flexkv_logger.info(f"[SharedOpPool] block ids buffer data_ptr: {self.buffer.storage().data_ptr()}") + + self.free_slots = deque(range(max_op_num)) + self.slot_map = dict() # {slot_hash: slot_id} + + self.slot_ref_count = np.zeros(max_op_num, dtype=np.int32) + self.slot_hashes = [0]*max_op_num + + self.lock = threading.Lock() + + def allocate_slot(self, block_ids: np.ndarray): + """ + Allocating a slot for the given block ids + Params: + block_ids: the block ids of src address or dst address + Returns: + slot_id: the slot which is assigned to the given block ids, -1 if failed + """ + # firstly, determine whether the length of block ids exceeds the limit + num_blocks = block_ids.size + if num_blocks > self.max_block_num or num_blocks == 0: + return -1 + + slot_hash = hash_array(block_ids) + reuse = 
False + + # get the slot of empty buffer + with self.lock: + if slot_hash in self.slot_map: + slot_id = self.slot_map[slot_hash] + reuse = True + else: + if not self.free_slots: + flexkv_logger.info("No empty slot in SharedOpPool") + return -1 + + slot_id = self.free_slots.popleft() + self.slot_map[slot_hash] = slot_id + + # update status managers + self.slot_ref_count[slot_id] += 1 + self.slot_hashes[slot_id] = slot_hash + + # do copy + if not reuse: + self.buffer[slot_id, :num_blocks] = torch.from_numpy(block_ids).to(torch.int64) + + return slot_id + + def free_slot(self, slot_id: int): + """ + Free the relevant resources of corresponding op, called when op transfer completed. + Input: + op_id: the index of current op + Output: + None + """ + with self.lock: + slot_hash = self.slot_hashes[slot_id] + if slot_hash not in self.slot_map: + raise RuntimeError(f"Slot {slot_id} is not in use, double free detected!") + self.slot_ref_count[slot_id] -= 1 + assert self.slot_ref_count[slot_id] >= 0, f"Slot {slot_id} ref count is negative" + if self.slot_ref_count[slot_id] == 0: + self.free_slots.append(slot_id) + del self.slot_map[slot_hash] + + def get_buffer(self): + return self.buffer + + def get_buffer_size(self): + return self.max_op_num, self.max_block_num + + def status(self): + """ + Current status logger + """ + with self.lock: + used = len(self.slot_map) + free = self.max_op_num - used + return {"used_slots": used, + "free_slots": free, + "capacity": self.max_op_num} + + +if __name__ == "__main__": + manager = SharedOpPool(4, 10) diff --git a/flexkv/common/tracer.py b/flexkv/common/tracer.py index 45178e85b4..dff6b1ff3a 100644 --- a/flexkv/common/tracer.py +++ b/flexkv/common/tracer.py @@ -121,7 +121,6 @@ def trace_config(self, model_config, cache_config, gpu_layout=None): "ssd_kv_layout_type": str(cache_config.ssd_kv_layout_type), "remote_kv_layout_type": str(cache_config.remote_kv_layout_type), "use_gds": cache_config.use_gds, - "use_pinned_memory": cache_config.use_pinned_memory, "remote_cache_size_mode": cache_config.remote_cache_size_mode, "num_cpu_blocks": cache_config.num_cpu_blocks, "num_ssd_blocks": cache_config.num_ssd_blocks, @@ -134,6 +133,7 @@ def trace_config(self, model_config, cache_config, gpu_layout=None): "ssd_cache_iouring_flags": cache_config.ssd_cache_iouring_flags, "remote_cache_path": cache_config.remote_cache_path, "remote_config_custom": cache_config.remote_config_custom, + "evict_ratio": cache_config.evict_ratio, } # Convert gpu_layout to dict if provided diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 342f0678c0..91229b7834 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -3,7 +3,7 @@ from enum import Enum from typing import ClassVar, List, Set, Dict -import torch +import numpy as np class DeviceType(Enum): @@ -31,16 +31,6 @@ class PartitionBlockType(Enum): ROUND_ROBIN = 0 SEQUENTIAL = 1 -@dataclass -class TransferDescriptor: - device_type: DeviceType = DeviceType.CPU - device_id: int = 0 - physical_block_ids: torch.Tensor = torch.tensor([], dtype=torch.int64) - - def __post_init__(self) -> None: - assert self.physical_block_ids.ndim == 1 - assert self.physical_block_ids.dtype == torch.int64 - class TransferOpStatus(Enum): PENDING = 0 RUNNING = 1 @@ -54,24 +44,31 @@ class TransferOp: op_id: int = field(init=False) graph_id: int transfer_type: TransferType - layer_id: int - layer_granularity: int - src_descriptor: TransferDescriptor = field(default_factory=TransferDescriptor) - dst_descriptor: TransferDescriptor = 
field(default_factory=TransferDescriptor) + src_block_ids: np.ndarray + dst_block_ids: np.ndarray + layer_id: int = 0 + layer_granularity: int = -1 # this will change dynamically as transfer ops executed predecessors: Set[int] = field(default_factory=set) # this will keep the full info successors: Set[int] = field(default_factory=set) status: TransferOpStatus = TransferOpStatus.PENDING dp_id: int = 0 + # used for get block ids inner worker process + src_slot_id: int = -1 + dst_slot_id: int = -1 + valid_block_num: int = 0 def __post_init__(self) -> None: if self.transfer_type != TransferType.VIRTUAL and \ - len(self.src_descriptor.physical_block_ids) != len(self.dst_descriptor.physical_block_ids): - raise ValueError("src_descriptor and dst_descriptor must have the same number of physical blocks") + self.src_block_ids.size != self.dst_block_ids.size: + raise ValueError("src_block_ids and dst_block_ids must have the same number of physical blocks") with TransferOp._lock: self.op_id = TransferOp._next_op_id TransferOp._next_op_id += 1 + assert self.src_block_ids.dtype == np.int64 + assert self.dst_block_ids.dtype == np.int64 + self.valid_block_num = self.src_block_ids.size class TransferOpGraph: @@ -79,13 +76,14 @@ class TransferOpGraph: _lock = threading.Lock() def __init__(self) -> None: - self.graph_id = self._get_next_graph_id() + self.graph_id = self._get_graph_id() self._op_map: Dict[int, TransferOp] = {} self._ready_ops: Set[int] = set() self._trigger_ops: Set[int] = set() + self._gpu_transfer_op_id: int = -1 @classmethod - def _get_next_graph_id(cls) -> int: + def _get_graph_id(cls) -> int: with cls._lock: graph_id = cls._next_graph_id cls._next_graph_id += 1 @@ -115,6 +113,14 @@ def trigger_op(self, op_id: int) -> None: def add_transfer_op(self, op: TransferOp) -> None: op.graph_id = self.graph_id self._op_map[op.op_id] = op + if op.transfer_type == TransferType.H2D or \ + op.transfer_type == TransferType.D2H or \ + op.transfer_type == TransferType.D2DISK or \ + op.transfer_type == TransferType.DISK2D: + if self._gpu_transfer_op_id == -1: + self._gpu_transfer_op_id = op.op_id + else: + raise ValueError("Only one GPU transfer op is allowed") self._ready_ops.add(op.op_id) def add_dependency(self, successor_op_id: int, predecessor_op_id: int) -> None: @@ -130,8 +136,6 @@ def mark_completed(self, op_id: int) -> None: assert self._op_map[op_id].status == TransferOpStatus.RUNNING self._op_map[op_id].status = TransferOpStatus.COMPLETED my_successors = self._op_map[op_id].successors - if len(my_successors) == 0: - return for successor_id in my_successors: self._op_map[successor_id].predecessors.remove(op_id) @@ -164,6 +168,16 @@ def all_transfer_ops_completed(self) -> bool: return all(op.status == TransferOpStatus.COMPLETED for op in self._op_map.values()) + def set_gpu_blocks(self, gpu_blocks: np.ndarray) -> None: + transfer_type = self._op_map[self._gpu_transfer_op_id].transfer_type + op = self._op_map[self._gpu_transfer_op_id] + if transfer_type.name.endswith("2D"): + op.dst_block_ids = gpu_blocks + else: + op.src_block_ids = gpu_blocks + assert op.src_block_ids.size == op.dst_block_ids.size, \ + f"src_block_ids.size={op.src_block_ids.size}, dst_block_ids.size={op.dst_block_ids.size}" + @property def num_ops(self) -> int: return len(self._op_map) @@ -172,69 +186,6 @@ def bind_to_dp_group(self, dp_id: int) -> None: for op in self._op_map.values(): op.dp_id = dp_id - def print_op_map(self) -> None: - """Print transfer op graph in a visual format showing dependencies. 
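`TransferOpGraph` now remembers its single GPU-facing op, so the GPU block ids can be bound late via `set_gpu_blocks()`. A minimal sketch with illustrative ids; the placeholder destination array only needs to have the right length:

```python
import numpy as np
from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferType

graph = TransferOpGraph()
op = TransferOp(graph_id=graph.graph_id,
                transfer_type=TransferType.H2D,
                src_block_ids=np.array([7, 8, 9], dtype=np.int64),
                dst_block_ids=np.zeros(3, dtype=np.int64))  # GPU ids not known yet
graph.add_transfer_op(op)

# the transfer type name ends with "2D", so the GPU ids land on the destination side
graph.set_gpu_blocks(np.array([11, 12, 13], dtype=np.int64))
```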
- - Example output: - Transfer Graph 5: - ├── Op 1 (H2D) [Completed] - │ └── No successors - ├── Op 2 (D2H) [Pending] - │ └── Followed by: 1 - └── Op 3 (DISK2H) [Pending] - └── Followed by: 1, 2 - """ - print(f"Transfer Graph {self.graph_id}:") - - # get all op ids and sort them - op_ids = sorted(self._op_map.keys()) - - for i, op_id in enumerate(op_ids): - op = self._op_map[op_id] - is_last = (i == len(op_ids) - 1) - - # draw the tree structure branch - prefix = "└── " if is_last else "├── " - - # get the op status - status = "[Completed]" if op.status == TransferOpStatus.COMPLETED else "[Pending]" - - # print the op info - print(f"{prefix}Op {op_id} ({op.transfer_type.name}) {status}") - - if op.transfer_type == TransferType.VIRTUAL: - continue - # print the dependency info - dep_prefix = " " if is_last else "│ " - if not op.successors: - print(f"{dep_prefix}└── No successors") - else: - deps_str = ", ".join(str(dep) for dep in sorted(op.successors)) - print(f"{dep_prefix}└── Followed by: {deps_str}") - - # print the transfer details - src_info = f"From: {op.src_descriptor.device_type.name}:{op.src_descriptor.device_id}" - dst_info = f"To: {op.dst_descriptor.device_type.name}:{op.dst_descriptor.device_id}" - print(f"{dep_prefix} └── {src_info} -> {dst_info}") - - print(f"{dep_prefix} └── layers: {op.layer_id} - {op.layer_id + op.layer_granularity}") - - # if there are physical block ids, also print them - if len(op.src_descriptor.physical_block_ids) > 0: - blocks = op.src_descriptor.physical_block_ids.tolist() - if len(blocks) > 3: - blocks_str = f"{blocks[:3]}... ({len(blocks)} blocks)" - else: - blocks_str = str(blocks) - print(f"{dep_prefix} └── Src Blocks: {blocks_str}") - if len(op.dst_descriptor.physical_block_ids) > 0: - blocks = op.dst_descriptor.physical_block_ids.tolist() - if len(blocks) > 3: - blocks_str = f"{blocks[:3]}... ({len(blocks)} blocks)" - else: - blocks_str = str(blocks) - print(f"{dep_prefix} └── Dst Blocks: {blocks_str}") - def get_nvtx_default_color() -> int: return 0xD3D3D3 diff --git a/flexkv/integration/__init__.py b/flexkv/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py new file mode 100644 index 0000000000..76f27f5b34 --- /dev/null +++ b/flexkv/integration/config.py @@ -0,0 +1,68 @@ + +import json +import os +import torch +import tempfile +from typing import TYPE_CHECKING +from dataclasses import dataclass, field + +from flexkv.common.debug import flexkv_logger + +if TYPE_CHECKING: + from vllm.v1.kv_cache_interface import KVCacheConfig, FullAttentionSpec + from vllm.config import VllmConfig + + +logger = flexkv_logger + +@dataclass +class FlexKVConfig: + #base config + server_recv_port: str + + # cache config + cache_config: dict = field(default_factory=dict) + + # model config + block_size: int = None + num_layers: int = None + num_kv_heads: int = None + head_size: int = None + dtype: torch.dtype = None + use_mla: bool = False + tp_size: int = 1 + + # log config + num_log_interval_requests: int = 200 + + @classmethod + def from_env(cls) -> 'FlexKVConfig': + config_file_path = os.getenv('FLEXKV_CONFIG_PATH', None) + logger.info(f"{config_file_path=}") + if config_file_path is None: + return cls(enable_flexkv=False, + server_recv_port="") + + assert config_file_path.endswith(".json"), "flexkv config must be a json file." 
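For reference, a hypothetical `FLEXKV_CONFIG_PATH` file matching the keys read by `from_env()`; only `server_recv_port`, `cache_config`, and `num_log_interval_requests` are consumed, and all values below are illustrative:

```python
import json
import os
import tempfile

# illustrative config dict; cache_config entries are forwarded to CacheConfig
cfg_dict = {
    "server_recv_port": "ipc:///tmp/flexkv_test",
    "cache_config": {"enable_ssd": True, "index_accel": True, "num_cpu_blocks": 100000},
    "num_log_interval_requests": 200,
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(cfg_dict, f)
os.environ["FLEXKV_CONFIG_PATH"] = f.name  # from_env() requires a .json path

from flexkv.integration.config import FlexKVConfig
flexkv_config = FlexKVConfig.from_env()
```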
+ + with open(config_file_path, 'r') as f: + config_dict: dict = json.load(f) + logger.info(f"FlexKV Config Dict: {config_dict}") + + return cls( + server_recv_port=config_dict.get("server_recv_port", f"ipc:///tmp/flexkv_test"), + cache_config=config_dict.get("cache_config", {}), + num_log_interval_requests=config_dict.get("num_log_interval_requests", 200), + ) + + def post_init_from_vllm_config( + self, + vllm_config: "VllmConfig", + ): + self.num_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) + self.block_size = vllm_config.cache_config.block_size + self.num_kv_heads = vllm_config.model_config.get_total_num_kv_heads() + self.head_size = vllm_config.model_config.get_head_size() + self.dtype = vllm_config.model_config.dtype + self.use_mla = vllm_config.model_config.is_deepseek_mla + self.tp_size = vllm_config.parallel_config.tensor_parallel_size \ No newline at end of file diff --git a/flexkv/integration/stats.py b/flexkv/integration/stats.py new file mode 100644 index 0000000000..3f4d1a70f9 --- /dev/null +++ b/flexkv/integration/stats.py @@ -0,0 +1,100 @@ +import time +from dataclasses import dataclass +from collections import deque + +from flexkv.common.debug import flexkv_logger + +logger = flexkv_logger + + +@dataclass +class FlexKVStats: + num_log_interval_requests: int + + # get info + num_get_requests: int = 0 + num_get_query_tokens: int = 0 + num_gpu_matched_tokens: int = 0 + num_flexkv_matched_tokens: int = 0 + + # put info + num_put_requests: int = 0 + num_put_query_tokens: int = 0 + num_put_unmatched_tokens: int = 0 + + num_failed_requests: int = 0 + + @property + def tatal_num_requests(self) -> int: + return self.num_get_requests + self.num_put_requests + + @property + def get_gpu_match_ratio(self) -> float: + if self.num_get_query_tokens == 0: + return 0.0 + return self.num_gpu_matched_tokens / self.num_get_query_tokens + + @property + def get_flexkv_match_ratio(self) -> float: + if self.num_get_query_tokens == 0: + return 0.0 + return self.num_flexkv_matched_tokens / self.num_get_query_tokens + + @property + def get_put_token_ratio(self) -> float: + if self.num_put_unmatched_tokens == 0: + return 0.0 + return self.num_flexkv_matched_tokens / self.num_put_unmatched_tokens + + def record_get( + self, + num_prompt_tokens: int, + num_gpu_matched_tokens: int, + num_flexkv_matched_tokens: int, + ): + self.num_get_requests += 1 + self.num_get_query_tokens += num_prompt_tokens + self.num_gpu_matched_tokens += num_gpu_matched_tokens + self.num_flexkv_matched_tokens += num_flexkv_matched_tokens + if self.num_get_requests == self.num_log_interval_requests: + self.log() + self.clear() + + def record_put( + self, + num_all_tokens: int, + num_unmatched_tokens: int, + ): + self.num_put_requests += 1 + self.num_put_query_tokens += num_all_tokens + self.num_put_unmatched_tokens += num_unmatched_tokens + + def record_faild( + self, + num_failed_requests: int + ): + self.num_failed_requests += num_failed_requests + + def clear(self): + self.num_get_requests = 0 + self.num_get_query_tokens = 0 + self.num_gpu_matched_tokens = 0 + self.num_flexkv_matched_tokens = 0 + self.num_put_requests = 0 + self.num_put_query_tokens = 0 + self.num_put_unmatched_tokens = 0 + self.num_failed_requests = 0 + + def log(self): + if self.num_put_unmatched_tokens == 0: + get_put_token_ratio_str = "Nan" + else: + get_put_token_ratio_str = f"{self.get_put_token_ratio*100:.2f}%" + logger.info( + f"[FlexKV] Metric of Recent {self.num_log_interval_requests} Requests: " + f"Num Failed Request: 
{self.num_failed_requests}, " + f"Num Get Query Tokens: {self.num_get_query_tokens}, " + f"GPU Hit Ratio: {self.get_gpu_match_ratio*100:.2f}%, " + f"FlexKV Hit Ratio: {self.get_flexkv_match_ratio*100:.2f}%, " + f"Get/Put Token Ratio: {get_put_token_ratio_str}.") + \ No newline at end of file diff --git a/flexkv/integration/utils.py b/flexkv/integration/utils.py new file mode 100644 index 0000000000..9b107b7c23 --- /dev/null +++ b/flexkv/integration/utils.py @@ -0,0 +1,5 @@ + + +def cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return -(a // -b) \ No newline at end of file diff --git a/flexkv/integration/vllm/__init__.py b/flexkv/integration/vllm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flexkv/integration/vllm/vllm_v1_adapter.py b/flexkv/integration/vllm/vllm_v1_adapter.py new file mode 100644 index 0000000000..c129e44563 --- /dev/null +++ b/flexkv/integration/vllm/vllm_v1_adapter.py @@ -0,0 +1,746 @@ +import os +import time +from typing import TYPE_CHECKING, Optional, Literal, Any +from dataclasses import dataclass, field +from abc import ABC, abstractmethod + +import numpy as np +import torch + +from flexkv.kvmanager import KVManager +from flexkv.server.client import KVTPClient +from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.config import ModelConfig, CacheConfig +from flexkv.common.request import KVResponseStatus +from flexkv.common.debug import flexkv_logger +from flexkv.integration.stats import FlexKVStats +from flexkv.integration.utils import cdiv +from flexkv.integration.config import FlexKVConfig + +# vllm +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, KVConnectorRole) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + from vllm.v1.outputs import KVConnectorOutput + + +logger = flexkv_logger + + +@dataclass +class FlexKVResponse: + task_id: int + task_type: Literal["get", "put"] + request: "Request" + success: bool + + +@dataclass +class FlexKVTask(ABC): + task_id: int = 0 + request: "Request" = 0 + + # slot mapping + slot_mapping: Optional[np.ndarray] = None + + # timer + match_start_time: float = 0 + match_end_time: float = 0 + task_launch_time: float = 0 + task_finished_time: float = 0 + + @property + def match_cost(self) -> float: + return (self.match_end_time - self.match_start_time) + + @property + def task_execute_cost(self) -> float: + return (self.task_finished_time - self.task_launch_time) + + @property + @abstractmethod + def task_type(self) -> str: + ... 
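The `cdiv` helper above implements ceiling division via floor division on a negated operand; the adapter uses it to round token counts to block boundaries. A quick check of the arithmetic:

```python
from flexkv.integration.utils import cdiv

assert cdiv(10, 4) == 3   # -(10 // -4) == -(-3) == 3
assert cdiv(8, 4) == 2
assert cdiv(1, 16) == 1
```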
+ + def __str__(self): + return (f"FlexKVTask(task_id={self.task_id}, " + f"request={self.request.request_id}, " + f"match_cost {self.match_cost*1000:.2f} ms, " + f"task execute cost {self.task_execute_cost*1000:.2f} ms)") + + +@dataclass(kw_only=True) +class FlexKVGetTask(FlexKVTask): + num_computed_tokens: int + num_new_matched_tokens: int + + @property + def task_type(self) -> str: + return "get" + + def __str__(self): + return (f"FlexKVGetTask(task_id={self.task_id}, " + f"request={self.request.request_id}, " + f"num_computed_tokens={self.num_computed_tokens}, " + f"num_new_matched_tokens={self.num_new_matched_tokens}, " + f"match_cost {self.match_cost*1000:.2f} ms, " + f"task execute cost {self.task_execute_cost*1000:.2f} ms)") + + +@dataclass(kw_only=True) +class FlexKVPutTask(FlexKVTask): + num_matched_tokens: int + num_unmatched_tokens: int + + @property + def task_type(self) -> str: + return "put" + + def __str__(self): + return (f"FlexKVPutTask(task_id={self.task_id}, " + f"request={self.request.request_id}, " + f"num_matched_tokens={self.num_matched_tokens}, " + f"num_unmatched_tokens={self.num_unmatched_tokens}, " + f"match_cost {self.match_cost*1000:.2f} ms, " + f"task execute cost {self.task_execute_cost*1000:.2f} ms)") + + +class FlexKVSchedulerConnector: + def __init__( + self, + flexkv_config: FlexKVConfig + ): + logger.info(f"Start init FlexKVSchedulerConnector with {flexkv_config}") + self.flexkv_config = flexkv_config + self.server_recv_port = flexkv_config.server_recv_port + self.tp_size = flexkv_config.tp_size + self.block_size = flexkv_config.block_size + self.model_config = ModelConfig( + num_layers=flexkv_config.num_layers, + num_kv_heads=flexkv_config.num_kv_heads, + head_size=flexkv_config.head_size, + use_mla=flexkv_config.use_mla, + dtype=flexkv_config.dtype, + tp_size=flexkv_config.tp_size, + ) + if "tokens_per_block" in flexkv_config.cache_config: + assert flexkv_config.cache_config.pop("tokens_per_block") == flexkv_config.block_size + self.cache_config = CacheConfig( + tokens_per_block=flexkv_config.block_size, + **flexkv_config.cache_config, + ) + self.flexkv_manager = KVManager(model_config=self.model_config, + cache_config=self.cache_config, + gpu_register_port=flexkv_config.server_recv_port) + self.flexkv_manager.start() + # self.dp_client = KVDPClient(self.server_recv_port, self.model_config) + + # request_id -> task_id + self.req_id_to_task_dict: dict[str, int] = {} + # launched but unfinished tasks + self.get_tasks: dict[int, FlexKVGetTask] = {} + self.put_tasks: dict[int, FlexKVPutTask] = {} + # unlaunched tasks + self.tasks_to_launch: dict[int, FlexKVTask] = {} + self.tasks_to_cancel: dict[int, FlexKVTask] = {} + + self.flexkv_stats = FlexKVStats(flexkv_config.num_log_interval_requests) + + while not self.is_ready(): + logger.info("Waiting for flexkv init...") + time.sleep(5) + + logger.info("Finish init FlexKVSchedulerConnector") + + def is_ready( + self, + ) -> bool: + " Ask flexkv is ready " + return self.flexkv_manager.is_ready() + + def shutdown(self) -> None: + self.flexkv_manager.shutdown() + + @property + def dp_client_id(self) -> int: + return self.flexkv_manager.dp_client_id + + #################### + #### Get Method #### + #################### + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Args: + request: Request to get. + num_computed_tokens: Number of prefix tokens have already been computed, + which means not need to transfer from flexkv. 
+ + Returns: + tuple[int, bool]: A tuple containing the number of new matched tokens and + a bool indicating whether the new matched blocks need to be + fetched from flexkv. + """ + task_id, num_new_matched_tokens = self._get_match(request=request, + num_computed_tokens=num_computed_tokens) + self.flexkv_stats.record_get(num_prompt_tokens=request.num_tokens, + num_gpu_matched_tokens=num_computed_tokens, + num_flexkv_matched_tokens=num_new_matched_tokens) + + if not self._need_to_get(num_prompt_tokens=request.num_tokens, + num_computed_tokens=num_computed_tokens, + num_new_matched_tokens=num_new_matched_tokens): + return 0, False + + return num_new_matched_tokens, True + + + def _get_match( + self, + request: "Request", + num_computed_tokens: int = 0, + ) -> tuple[int, int]: + """ + Args: + request: Request to get. + num_computed_tokens: Number of prefix tokens that have already been computed + and therefore do not need to be transferred from flexkv. + + Returns: + tuple[int, int]: A tuple containing two integer values representing + the task_id and number of new matched tokens. + """ + match_start_time = time.perf_counter() + num_tokens_to_get = (request.num_tokens//self.block_size)*self.block_size + token_ids = request.all_token_ids[:num_tokens_to_get] + + assert num_computed_tokens <= num_tokens_to_get, ( + f"{num_computed_tokens=} must be less than or equal to {num_tokens_to_get=}") + assert num_computed_tokens % self.block_size == 0 + + if num_tokens_to_get == num_computed_tokens: + return -1, 0 + + np_token_ids = np.array(token_ids) + np_token_mask = np.ones_like(np_token_ids, dtype=bool) + np_token_mask[:num_computed_tokens] = False + task_id, matched_mask = self.flexkv_manager.get_match(token_ids=np_token_ids, + token_mask=np_token_mask) + num_new_matched_tokens = matched_mask.sum().item() + + # Auto-cancel if update_state_after_alloc() is not called + match_end_time = time.perf_counter() + logger.debug(f"Get match cost {(match_end_time-match_start_time)*1000:.2f} ms.") + if num_new_matched_tokens > 0: + self.req_id_to_task_dict[request.request_id] = task_id + self.tasks_to_cancel[task_id] = FlexKVGetTask(task_id=task_id, + request=request, + num_computed_tokens=num_computed_tokens, + num_new_matched_tokens=num_new_matched_tokens, + match_start_time=match_start_time, + match_end_time=match_end_time) + + logger.debug(f"FlexKV create get task: {self.tasks_to_cancel[task_id]}") + + return task_id, num_new_matched_tokens + + def _need_to_get( + self, + num_prompt_tokens: int, + num_computed_tokens: int, + num_new_matched_tokens: int, + ) -> bool: + """ + Determine whether it is necessary to get the new matched blocks from flexkv. + """ + return num_new_matched_tokens > 0 + + def update_state_after_alloc( + self, + request: "Request", + blocks: "KVCacheBlocks", + num_new_matched_tokens: int, + ) -> None: + """ + Compute the slot mapping and prepare to launch the task. + Only call after get_num_new_matched_tokens(). + + Args: + request: Request to get. + blocks: All blocks of the request. + num_new_matched_tokens: Number of new matched tokens returned by + get_num_new_matched_tokens(). + + Returns: + None.
+ """ + if num_new_matched_tokens == 0: + return + # prepare to launch task + task_id = self.req_id_to_task_dict[request.request_id] + task: FlexKVGetTask = self.tasks_to_cancel.pop(task_id) + self.tasks_to_launch[task_id] = task + + # compute slot_mapping + num_computed_blocks = task.num_computed_tokens // self.block_size + num_blocks_to_get = num_new_matched_tokens // self.block_size + all_block_ids = blocks.get_block_ids()[0] + block_ids_to_get = all_block_ids[num_computed_blocks:num_computed_blocks+num_blocks_to_get] + task.slot_mapping = np.array(block_ids_to_get).repeat(self.block_size)*self.block_size + + def wait_for_all_get_tasks(self) -> list[FlexKVResponse]: + """ + Blocking wait for all get tasks. + + Returns: + list[FlexKVResponse]: Responses of all get tasks. + """ + return self._blocking_waiting_for_tasks(self.get_tasks) + + #################### + #### Put Method #### + #################### + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> bool: + """ + Args: + request: Request to put. + blocks: All block_ids of the request. + + Returns: + bool: whether thire is unfinished task for this request. + """ + # Task not finished, can't free blocks + if request.request_id in self.req_id_to_task_dict: + return True + + # Abnormal finished, don't put + if not (request.is_finished() and request.get_finished_reason() < 2): + return False + + task_id, num_matched_tokens, num_unmatched_tokens = self._put_match(request=request) + + self.flexkv_stats.record_put(num_all_tokens=request.num_tokens, + num_unmatched_tokens=num_unmatched_tokens) + + if not self._need_to_put(num_all_tokens=request.num_tokens, + num_matched_tokens=num_matched_tokens, + num_unmatched_tokens=num_unmatched_tokens): + return False + + # prepare to launch task + task: FlexKVPutTask = self.tasks_to_cancel.pop(task_id) + self.tasks_to_launch[task_id] = task + + # compute slot mapping + # num_blocks_to_put = (num_matched_tokens+num_unmatched_tokens) // self.block_size + num_matched_blocks = num_matched_tokens // self.block_size + num_unmatched_tokens = num_unmatched_tokens // self.block_size + block_ids_to_put = block_ids[num_matched_blocks:num_matched_blocks+num_unmatched_tokens] + task.slot_mapping = np.array(block_ids_to_put).repeat(self.block_size)*self.block_size + + return True + + def _put_match( + self, + request: "Request" + ) -> tuple[int, int, int]: + """ + Args: + request: Request to put. + + Returns: + tuple[int, int, int]: A tuple containing three integer values representing + the task_id, number of matched tokens and number of unmatched tokens. + """ + match_start_time = time.perf_counter() + num_tokens_to_put = (cdiv(request.num_tokens+1, self.block_size)-1)*self.block_size + token_ids = request.all_token_ids[:num_tokens_to_put] + + if num_tokens_to_put == 0: + return -1, 0, 0 + + np_token_ids = np.array(token_ids) + task_id, unmatched_mask = self.flexkv_manager.put_match(token_ids=np_token_ids) + + num_unmatched_tokens = unmatched_mask.sum().item() + num_matched_tokens = num_tokens_to_put - num_unmatched_tokens + + # Auto cancel if not need to put. 
+ match_end_time = time.perf_counter() + logger.debug(f"Put match cost {(match_end_time-match_start_time)*1000:.2f} ms.") + + if num_unmatched_tokens > 0: + self.req_id_to_task_dict[request.request_id] = task_id + self.tasks_to_cancel[task_id] = FlexKVPutTask(task_id=task_id, + request=request, + num_matched_tokens=num_matched_tokens, + num_unmatched_tokens=num_unmatched_tokens, + match_start_time=match_start_time, + match_end_time=match_end_time) + logger.debug(f"FlexKV create put task: {self.tasks_to_cancel[task_id]}") + + return task_id, num_matched_tokens, num_unmatched_tokens + + def _need_to_put( + self, + num_all_tokens: int, + num_matched_tokens: int, + num_unmatched_tokens: int, + ) -> bool: + """ + Determine whether it is necessary to put the unmatched blocks into flexkv. + """ + return num_unmatched_tokens > 0 + + def wait_for_all_put_tasks(self) -> list[FlexKVResponse]: + """ + Blocking wait for all put tasks. + + Returns: + list[FlexKVResponse]: Responses of all put tasks. + """ + return self._blocking_waiting_for_tasks(self.put_tasks) + + ####################### + #### Common Method #### + ####################### + + def cancel_tasks(self) -> None: + """ + Cancel tasks in self.tasks_to_cancel. + Called before launch_tasks() to remove their req_ids from self.req_id_to_task_dict. + """ + # TODO: check if this method is in-proc. + if len(self.tasks_to_cancel) == 0: + return + for task in self.tasks_to_cancel.values(): + del self.req_id_to_task_dict[task.request.request_id] + logger.info(f"FlexKV Cancel task: {task}") + self.flexkv_manager.cancel(task_ids=list(self.tasks_to_cancel.keys())) + self.tasks_to_cancel.clear() + + def launch_tasks(self) -> None: + """ + Launch tasks in self.tasks_to_launch. + """ + if len(self.tasks_to_launch) == 0: + return + task_launch_time = time.perf_counter() + task_ids: list[int] = [] + slot_mappings: list[np.ndarray] = [] + + for task_id, task in self.tasks_to_launch.items(): + logger.info(f"FlexKV Launch task: {task}") + task.task_launch_time = task_launch_time + task_ids.append(task_id) + slot_mappings.append(task.slot_mapping) + if isinstance(task, FlexKVGetTask): + self.get_tasks[task_id] = task + else: + self.put_tasks[task_id] = task + self.flexkv_manager.launch(task_ids=task_ids, + slot_mappings=slot_mappings) + self.tasks_to_launch.clear() + + def query_finished_task(self) -> tuple[set[str], set[str]]: + """ + Collect responses of finished tasks. + + Returns: + tuple[set[str], set[str]]: Request ids whose put (sending) and get (recving) transfers have finished.
+ """ + if len(self.req_id_to_task_dict) == 0: + return set(), set() + logger.debug(f"unfinished task: {self.req_id_to_task_dict}") + task_ids = list(self.get_tasks.keys()) + list(self.put_tasks.keys()) + responses_from_manager = self.flexkv_manager.try_wait(task_ids) + task_finished_time = time.perf_counter() + # responses_to_return: list[FlexKVResponse] = [] + finished_sending = set() + finished_recving = set() + num_failed_tasks = 0 + for task_id, response in responses_from_manager.items(): + success = (response.status == KVResponseStatus.SUCCESS) + if task_id in self.get_tasks: + task = self.get_tasks.pop(task_id) + finished_recving.add(task.request.request_id) + else: + task = self.put_tasks.pop(task_id) + finished_sending.add(task.request.request_id) + del self.req_id_to_task_dict[task.request.request_id] + task.task_finished_time = task_finished_time + if success: + logger.info(f"{task} finished successfully.") + else: + logger.error(f"{task} failed, status: {response.status}.") + num_failed_tasks += 1 + # responses_to_return.append(FlexKVResponse(task_id=task_id, task_type=task.task_type, + # request=task.request, success=success)) + self.flexkv_stats.record_faild(num_failed_requests=num_failed_tasks) + return finished_sending, finished_recving + + def _blocking_waiting_for_tasks(self, task_dict: dict[int, FlexKVTask]) -> list[FlexKVResponse]: + """ + Blocking wait for tasks in task_dict. + + Returns: + list[FlexKVResponse]: Responses of all tasks in task_dict. + """ + if len(task_dict) == 0: + return [] + + task_ids = list(task_dict.keys()) + response_from_manager = self.flexkv_manager.wait(task_ids=task_ids) + task_finished_time = time.perf_counter() + responses_to_return: list[FlexKVResponse] = [] + for task_id, response in response_from_manager.items(): + success = (response.status == KVResponseStatus.SUCCESS) + task = task_dict.pop(task_id) + task.task_finished_time = task_finished_time + if success: + logger.info(f"{task} finished successfully.") + else: + logger.error(f"{task} failed, status: {response.status}.") + responses_to_return.append(FlexKVResponse(task_id=task_id, task_type=task.task_type, + request=task.request, success=success)) + return responses_to_return + + +class FlexKVWorkerConnector: + def __init__( + self, + flexkv_config: FlexKVConfig, + ): + current_device_id = torch.cuda.current_device() + self.flexkv_config = flexkv_config + logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.server_recv_port}") + self.tp_client = KVTPClient(flexkv_config.server_recv_port, 0, current_device_id) + logger.info("Finish init FlexKVWorkerConnector") + + def register_to_server(self, kv_caches: dict[str, torch.Tensor]): + logger.info("Start register kv_caches") + gpu_blocks = list(kv_caches.values()) + num_layer = len(kv_caches) + if self.flexkv_config.use_mla: + assert gpu_blocks[0].ndim == 3, ( + f"expect kv cached tensor has 3 dim but get shape={gpu_blocks[0].shape}.") + num_blocks = gpu_blocks[0].shape[0] + block_size = gpu_blocks[0].shape[1] + num_kv_heads = 1 + head_size = gpu_blocks[0].shape[2] + else: + assert gpu_blocks[0].ndim == 5, ( + f"expect kv cached tensor has 5 dim but get shape={gpu_blocks[0].shape}.") + num_blocks = gpu_blocks[0].shape[1] + block_size = gpu_blocks[0].shape[2] + num_kv_heads = gpu_blocks[0].shape[3] + head_size = gpu_blocks[0].shape[4] + gpu_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERWISE, + num_layer=num_layer, + num_block=num_blocks, + tokens_per_block=block_size, + num_head=num_kv_heads, + head_size=head_size, + 
is_mla=self.flexkv_config.use_mla, + ) + self.tp_client.register_to_server(gpu_blocks, gpu_layout) + logger.info("Finish register kv_caches") + + +class FlexKVConnectorV1Impl: + def __init__(self, vllm_config: "VllmConfig", role: "KVConnectorRole"): + self.role = role + flexkv_config = FlexKVConfig.from_env() + flexkv_config.post_init_from_vllm_config(vllm_config) + + if role == KVConnectorRole.SCHEDULER: + self.connector = FlexKVSchedulerConnector(flexkv_config) + elif role == KVConnectorRole.WORKER: + self.connector = FlexKVWorkerConnector(flexkv_config) + else: + raise ValueError(f"Unrecognized KVConnectorRole: {role}.") + + def shutdown(self): + if self.role == KVConnectorRole.SCHEDULER: + self.connector.shutdown() + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + pass + + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + pass + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer + (requests that previously returned True from request_finished()), + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + return None, None + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). 
+ + Args: kv_caches: + dictionary of layer names, kv cache + """ + self.connector.register_to_server(kv_caches) + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + return self.connector.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + self.connector.update_state_after_alloc(request, blocks, num_external_tokens) + + def build_connector_meta( + self, scheduler_output: "SchedulerOutput") -> "KVConnectorMetadata": + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + self.connector.cancel_tasks() + self.connector.launch_tasks() + return KVConnectorMetadata() + + def update_connector_output(self, connector_output: "KVConnectorOutput"): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. + """ + + finished_sending, finished_recving = self.connector.query_finished_task() + connector_output.finished_sending = finished_sending + connector_output.finished_recving = finished_recving + + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + return self.connector.request_finished(request, block_ids), None diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index fd75750530..8e85ee9959 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -12,566 +12,199 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
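+ # KVManager acts as a thin dispatch layer: with dp_size == 1 it drives a KVTaskEngine in-process; with dp_size > 1 it talks to a shared KVServer through a KVDPClient.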
-import multiprocessing as mp -import threading + +from typing import Optional, Tuple, List, Dict, Union, Iterable import time -from dataclasses import dataclass -from queue import Queue -from typing import Dict, Any, Optional -from typing import List, Callable, Union -import nvtx +import numpy as np import torch -from expiring_dict import ExpiringDict -from flexkv.cache.cache_engine import GlobalCacheEngine, TransferOpGraph -from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.server.client import KVDPClient +from flexkv.server.server import KVServer, DPClient +from flexkv.kvtask import KVTaskEngine, KVResponse +from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.debug import flexkv_logger -from flexkv.common.memory_handle import TensorSharedHandle -from flexkv.common.request import KVRequestType, KVRequest -from flexkv.common.transfer import DeviceType, get_nvtx_range_color, get_nvtx_default_color -from flexkv.common.storage import KVCacheLayout -from flexkv.common.exceptions import LogicError -from flexkv.common.tracer import FlexKVTracer -from flexkv.storage.storage_engine import StorageEngine -from flexkv.transfer.transfer_engine import TransferEngine - - -@dataclass -class RequestTracker: - task_id: int - task_type: KVRequestType - return_mask: torch.Tensor - callback: Optional[Callable] - task_end_ops_ids: List[int] - task_end_ops_status: List[bool] - task_finished: bool = False class KVManager: def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, - gpu_layout: Optional[KVCacheLayout] = None, - gpu_blocks: Optional[Dict[int, List[TensorSharedHandle]]] = None): - - flexkv_logger.info(f"Initializing kvmanager...\nmodel_config: {model_config}\ncache_config: {cache_config}") - - mp.set_start_method('spawn', force=True) - self.init_nvtx_range = nvtx.push_range("Initialize kvmanager", color=get_nvtx_default_color()) - - if not cache_config.enable_cpu: - raise ValueError("enable_cpu must be True") - if cache_config.enable_remote and not cache_config.enable_ssd: - raise ValueError("enable_ssd must be True if enable_remote is True") - if not cache_config.enable_cpu and not cache_config.use_gds: - raise ValueError("use_gds must be True if enable_cpu is False") - self.cache_config = cache_config + gpu_register_port: Optional[str] = None, + server_recv_port: Optional[str] = None, + dp_client_id: int = 0): + flexkv_logger.info(f"{model_config = }") + flexkv_logger.info(f"{cache_config = }") self.model_config = model_config - - self._verify_Model_Cache_config(model_config, cache_config) - self.cache_engine = GlobalCacheEngine(cache_config, model_config) - self.storage_engine = StorageEngine(self.model_config, self.cache_config) - - # Initialize tracer - self.tracer = FlexKVTracer(cache_config) - - # Record configuration in tracer - if gpu_layout is not None: - self.tracer.trace_config(model_config, cache_config, gpu_layout) - - - self.transfer_engine: Optional[TransferEngine] = None - self.gpu_layout: Optional[KVCacheLayout] = gpu_layout - - self.running = False - self.requests_tracker: ExpiringDict[int, RequestTracker] = ExpiringDict(1800) # 30 minutes - self.graph_to_request: Dict[int, int] = {} - self.taskid_to_nvtx_range: Dict[int, Any] = {} - self.graphid_to_nvtx_range: Dict[int, Any] = {} - - self._task_id_counter = 0 - self.task_queue: Queue[KVRequest] = Queue() - - if gpu_blocks is None: - gpu_blocks = {} - - self.num_gpus = len(gpu_blocks) - self.all_gpu_blocks: Dict[int, List[TensorSharedHandle]] = gpu_blocks - - self.lock = 
threading.Lock() - - if self.num_gpus == self.model_config.tp_size * self.model_config.dp_size: - self._init_transfer_engine() - - # Note that for now only after all the gpu blocks are added, we can initialize the transfer engine - def _init_transfer_engine(self) -> None: - assert self.gpu_layout is not None - assert len(self.all_gpu_blocks) == self.model_config.tp_size * self.model_config.dp_size - for device_id, gpu_blocks_wrapper in self.all_gpu_blocks.items(): - self.storage_engine.register_gpu_blocks(gpu_blocks_wrapper, - self.gpu_layout, - device_id, - dtype=self.model_config.dtype) - self.gpu_handles = [ - self.storage_engine.get_storage_handle(DeviceType.GPU, i) - for i in range(self.model_config.tp_size * self.model_config.dp_size) - ] - cpu_handle = self.storage_engine.get_storage_handle(DeviceType.CPU) if self.cache_config.enable_cpu else None - ssd_handle = self.storage_engine.get_storage_handle(DeviceType.SSD) if self.cache_config.enable_ssd else None - remote_handle = ( - self.storage_engine.get_storage_handle(DeviceType.REMOTE) - if self.cache_config.enable_remote - else None - ) - self.transfer_engine = TransferEngine(self.gpu_handles, - self.model_config, - self.cache_config, - cpu_handle, - ssd_handle, - remote_handle) - - nvtx.pop_range(self.init_nvtx_range) - - - def is_ready(self) -> bool: - return self.transfer_engine is not None - - def is_running(self) -> bool: - return self.running + self.cache_config = cache_config + self.gpu_register_port = gpu_register_port + self.server_recv_port = server_recv_port + self.server_client_mode = model_config.dp_size > 1 + self.dp_client_id = dp_client_id + flexkv_logger.info(f"server_client_mode: {self.server_client_mode}") + if self.server_client_mode: + # server should only be created once but kvmanager will init in every dp rank. + if dp_client_id == 0: + self.server_handle = KVServer.create_server(model_config, + cache_config, + gpu_register_port, + server_recv_port) + + else: + self.server_handle = None + self.dp_client = KVDPClient(self.server_recv_port, self.model_config, dp_client_id) + else: + self.server_handle = None + self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port) + + @property + def dpclient_id(self) -> int: + return self.dp_client_id def start(self) -> None: - if self.running: - flexkv_logger.warning("kvmanager is already running") - return - if not self.is_ready(): - raise ValueError("transfer engine is not ready, please add all gpu blocks first") - if self.transfer_engine is not None: - self.transfer_engine.start() - self.running = True + if not self.server_client_mode: + self.kv_task_engine.start() else: - raise ValueError("transfer engine is not initialized, please call start() after all gpu blocks are added") - - self._worker_thread = threading.Thread(target=self._worker_loop) - self._worker_thread.start() - flexkv_logger.info("KVManager fully started and running") + # send the start request to the server + self.dp_client.start_server_and_register() - # the gpu_blocks of multiple gpus can be added post initialization. - # the transfer engine will be initialized after we have all the intended gpu handles. 
- def register_single_gpu_blocks( - self, - gpu_handles: List[TensorSharedHandle], - gpu_layout: KVCacheLayout, - dp_client_id: int = 0, - tp_rank: int = 0, - ) -> None: - if self.transfer_engine is not None: - raise ValueError("we have already get all gpu blocks") - if self.gpu_layout is None: - self.gpu_layout = gpu_layout - self.tracer.trace_config(self.model_config, self.cache_config, self.gpu_layout) + def is_ready(self) -> bool: + if self.server_client_mode: + return self.dp_client.is_ready() else: - assert self.gpu_layout == gpu_layout - self.all_gpu_blocks[tp_rank + dp_client_id * self.model_config.tp_size] = gpu_handles - self.num_gpus += 1 - if self.num_gpus == self.model_config.tp_size * self.model_config.dp_size: - self._init_transfer_engine() - - def _worker_loop(self) -> None: - assert self.transfer_engine is not None - while self.running: - # deal with completed requests from the cache engine - if not self.task_queue.empty(): - request = self.task_queue.get() - if request.request_type == KVRequestType.SHUTDOWN: - self.shutdown() - break - elif request.request_type == KVRequestType.GET: - nvtx.push_range(f"cache_engine.get request_id: {request.request_id}", - color=get_nvtx_default_color()) - graph, return_mask, callback, task_end_ops_ids = self.cache_engine.get(request.request_id, - request.token_ids, - request.token_mask, - request.slot_mapping, - self.model_config.num_layers, - request.layer_granularity, - request.dp_id) - elif request.request_type == KVRequestType.PUT: - nvtx.push_range(f"cache_engine.put request_id: {request.request_id}", - color=get_nvtx_default_color()) - graph, return_mask, callback, task_end_ops_ids = self.cache_engine.put(request.request_id, - request.token_ids, - request.token_mask, - request.slot_mapping, - self.model_config.num_layers, - request.dp_id) - else: - raise ValueError(f"Unknown request type: {request.request_type}") - nvtx.pop_range() - if graph.num_ops == 0: #early return - flexkv_logger.info(f"no transfer: " - f"request_id = {request.request_id}, request_type = {request.request_type}") - layer_op_num = self.model_config.num_layers // request.layer_granularity \ - if request.request_type == KVRequestType.GET else 1 - self.requests_tracker[request.request_id] = RequestTracker(task_id=request.request_id, - task_type=request.request_type, - return_mask=return_mask, - callback=None, - task_end_ops_ids=[-1]*layer_op_num, - task_end_ops_status=[True]*layer_op_num, - task_finished=True) - else: - self.graph_to_request[graph.graph_id] = request.request_id - self.graphid_to_nvtx_range[graph.graph_id] = nvtx.start_range( - f"request id: {request.request_id}, " - f"graph id: {graph.graph_id}", - color=get_nvtx_range_color(graph.graph_id)) - self.requests_tracker[request.request_id] = RequestTracker(task_id=request.request_id, - task_type=request.request_type, - return_mask=return_mask, - callback=callback, - task_end_ops_ids=task_end_ops_ids, - task_end_ops_status=len(task_end_ops_ids)*[False], - task_finished=False) - self.transfer_engine.submit_transfer_graph(graph) - results = self.transfer_engine.get_completed_graphs_and_ops(timeout=0.001) - for completed_graph_id, completed_op_id in results: - request_id = self.graph_to_request[completed_graph_id] - request_tracker = self.requests_tracker[request_id] - if completed_op_id == -1: - if request_tracker.callback: - request_tracker.callback() - nvtx.end_range(self.graphid_to_nvtx_range[completed_graph_id]) - self.graphid_to_nvtx_range.pop(completed_graph_id) - 
self.graph_to_request.pop(completed_graph_id) - nvtx.end_range(self.taskid_to_nvtx_range[request_tracker.task_id]) - self.taskid_to_nvtx_range.pop(request_tracker.task_id) - request_tracker.task_finished = True - elif completed_op_id in request_tracker.task_end_ops_ids: - request_tracker.task_end_ops_status[request_tracker.task_end_ops_ids.index(completed_op_id)] = True - self.requests_tracker[request_id] = request_tracker - time.sleep(0.0001) - - def _get_task_id(self) -> int: - with self.lock: - old_value = self._task_id_counter - self._task_id_counter += 1 - return old_value - - def __del__(self) -> None: - if hasattr(self, 'tracer'): - self.tracer.flush() - if self.running: - self.shutdown() + return self.kv_task_engine.is_ready() def shutdown(self) -> None: - self.running = False - # Flush tracer before shutdown - if hasattr(self, 'tracer'): - self.tracer.flush() - flexkv_logger.info("kvmanager shutdown") - self.task_queue.put(KVRequest( - request_type=KVRequestType.SHUTDOWN, - request_id=-1, - token_ids=torch.empty(0), - token_mask=torch.empty(0), - slot_mapping=torch.empty(0), - )) - self._worker_thread.join() - if self.transfer_engine is not None: - self.transfer_engine.shutdown() + if self.server_client_mode: + self.dp_client.shutdown() + else: + self.kv_task_engine.shutdown() def get_async(self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor] = None, + token_ids: Union[torch.Tensor, np.ndarray], + slot_mapping: Union[torch.Tensor, np.ndarray], + token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, layer_granularity: int = -1, dp_id: int = 0, - task_id: int = -1) -> int: - if not self.running: - raise ValueError("kvmanager is not running, please call start() first") - if token_mask is None: - token_mask = torch.ones_like(token_ids) - if layer_granularity == -1: - layer_granularity = self.model_config.num_layers - if task_id == -1: - task_id = self._get_task_id() - # Trace the request - self.tracer.trace_request( - request_type="GET", - request_id=task_id, - token_ids=token_ids, - slot_mapping=slot_mapping, - token_mask=token_mask, - layer_granularity=layer_granularity, - dp_id=dp_id - ) - nvtx.mark(f"GET request_id: {task_id}") - self.taskid_to_nvtx_range[task_id] = nvtx.start_range(f"GET request_id: {task_id}", - color=get_nvtx_default_color()) - self.task_queue.put(KVRequest( - request_type=KVRequestType.GET, - request_id=task_id, - token_ids=token_ids, - token_mask=token_mask, - slot_mapping=slot_mapping, - layer_granularity=layer_granularity, - dp_id=dp_id, - )) - self.requests_tracker[task_id] = RequestTracker(task_id=task_id, - task_type=KVRequestType.GET, - return_mask=torch.empty(0), - callback=None, - task_end_ops_ids=[], - task_end_ops_status=[], - task_finished=False) + ) -> int: + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.numpy() + if isinstance(slot_mapping, torch.Tensor): + slot_mapping = slot_mapping.numpy() + if isinstance(token_mask, torch.Tensor): + token_mask = token_mask.numpy() + if self.server_client_mode: + task_id = self.dp_client.get_async(token_ids, + slot_mapping, + token_mask, + layer_granularity) + else: + task_id, _ = self.kv_task_engine.get_async(token_ids, + slot_mapping, + token_mask, + layer_granularity, + dp_id) return task_id + def get_match(self, + token_ids: Union[torch.Tensor, np.ndarray], + token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, + layer_granularity: int = -1, + dp_id: int = 0, + ) -> Tuple[int, np.ndarray]: + if isinstance(token_ids, 
torch.Tensor): + token_ids = token_ids.numpy() + if isinstance(token_mask, torch.Tensor): + token_mask = token_mask.numpy() + if self.server_client_mode: + task_id, mask = self.dp_client.get_match(token_ids, + token_mask, + layer_granularity) + else: + task_id, mask = self.kv_task_engine.get_match(token_ids, + token_mask, + layer_granularity, + dp_id) + return task_id, mask + def put_async(self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor] = None, + token_ids: Union[torch.Tensor, np.ndarray], + slot_mapping: Union[torch.Tensor, np.ndarray], + token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, dp_id: int = 0, - task_id: int = -1) -> int: - if not self.running: - raise ValueError("kvmanager is not running, please call start() first") - if token_mask is None: - token_mask = torch.ones_like(token_ids) - if task_id == -1: - task_id = self._get_task_id() - # Trace the request - self.tracer.trace_request( - request_type="PUT", - request_id=task_id, - token_ids=token_ids, - slot_mapping=slot_mapping, - token_mask=token_mask, - dp_id=dp_id - ) - nvtx.mark(f"PUT request_id: {task_id}") - self.taskid_to_nvtx_range[task_id] = nvtx.start_range(f"PUT request_id: {task_id}", - color=get_nvtx_default_color()) - self.task_queue.put(KVRequest( - request_type=KVRequestType.PUT, - request_id=task_id, - token_ids=token_ids, - token_mask=token_mask, - slot_mapping=slot_mapping, - dp_id=dp_id, - )) - self.requests_tracker[task_id] = RequestTracker(task_id=task_id, - task_type=KVRequestType.PUT, - return_mask=torch.empty(0), - callback=None, - task_end_ops_ids=[], - task_end_ops_status=[], - task_finished=False) + ) -> int: + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.numpy() + if isinstance(slot_mapping, torch.Tensor): + slot_mapping = slot_mapping.numpy() + if isinstance(token_mask, torch.Tensor): + token_mask = token_mask.numpy() + if self.server_client_mode: + task_id = self.dp_client.put_async(token_ids, slot_mapping, token_mask) + else: + task_id, _ = self.kv_task_engine.put_async(token_ids, slot_mapping, token_mask, dp_id) return task_id - # wait for the key op to be finished - def wait(self, task_ids: Union[int, List[int]], timeout: float = 20.0) -> Dict[int, torch.Tensor]: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="wait", - task_ids=task_ids, - ) - nvtx.mark(f"wait task_ids: {task_ids}") + def put_match(self, + token_ids: Union[torch.Tensor, np.ndarray], + token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, + dp_id: int = 0, + ) -> Tuple[int, np.ndarray]: + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.numpy() + if isinstance(token_mask, torch.Tensor): + token_mask = token_mask.numpy() + if self.server_client_mode: + task_id, mask = self.dp_client.put_match(token_ids, token_mask) + else: + task_id, mask = self.kv_task_engine.put_match(token_ids, token_mask, dp_id) + return task_id, mask + + def launch(self, + task_ids: Union[int, List[int]], + slot_mappings: Union[np.ndarray, List[np.ndarray], torch.Tensor, List[torch.Tensor]]) -> None: if isinstance(task_ids, int): task_ids = [task_ids] - num_completed_tasks = 0 - num_tasks = len(task_ids) - return_masks = {} - start_time = time.time() - while num_completed_tasks < num_tasks: - finished_task_ids = [] - for task_id in task_ids: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return_masks[task_id] = torch.empty(0, dtype=torch.bool) #if not found in 
tracker, the return mask is an empty tensor - num_completed_tasks += 1 - finished_task_ids.append(task_id) - continue - task_tracker = self.requests_tracker[task_id] - if len(task_tracker.return_mask) == 0: #NOT READY - continue - if all(task_tracker.task_end_ops_status): - num_completed_tasks += 1 - return_masks[task_id] = task_tracker.return_mask - finished_task_ids.append(task_id) - task_ids = [task_id for task_id in task_ids if task_id not in finished_task_ids] - if time.time() - start_time > timeout: - flexkv_logger.warning(f"wait task_ids: {task_ids} timeout, has to return now") - for task_id in task_ids: - return_masks[task_id] = torch.empty(0, dtype=torch.bool) # return mask of timeout task is also an empty tensor - nvtx.mark(f"wait task_ids: {task_ids} timeout") - return return_masks - time.sleep(0.0001) - nvtx.mark(f"wait task_ids: {task_ids} done") - return return_masks + if not isinstance(slot_mappings, List): + slot_mappings = [slot_mappings] + if isinstance(slot_mappings[0], torch.Tensor): + slot_mappings = [slot_mapping.numpy() for slot_mapping in slot_mappings] + if self.server_client_mode: + self.dp_client.launch_tasks(task_ids, slot_mappings) + else: + self.kv_task_engine.launch_tasks(task_ids, slot_mappings) - # wait for the whole task to be finished, including the key op and all other ops - # this function is mainly designed for testing to avoid the frequency of writing is too high to use up memory blocks - def wait_for_graph_finished(self, - task_ids: Union[int, List[int]], - timeout: float = 20.0) -> Dict[int, torch.Tensor]: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="wait_for_graph_finished", - task_ids=task_ids, - ) - nvtx.mark(f"wait task_ids: {task_ids}") + def cancel(self, task_ids: Union[int, List[int]]) -> None: if isinstance(task_ids, int): task_ids = [task_ids] - num_completed_tasks = 0 - return_masks = {} - start_time = time.time() - while num_completed_tasks < len(task_ids): - finished_task_ids = [] - for task_id in task_ids: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return_masks[task_id] = torch.empty(0) #if not found in tracker, the return mask is an empty tensor - num_completed_tasks += 1 - finished_task_ids.append(task_id) - continue - task_tracker = self.requests_tracker[task_id] - if task_tracker.task_finished: - num_completed_tasks += 1 - return_masks[task_id] = task_tracker.return_mask - finished_task_ids.append(task_id) - task_ids = [task_id for task_id in task_ids if task_id not in finished_task_ids] - if time.time() - start_time > timeout: - flexkv_logger.warning(f"wait task_ids: {task_ids} timeout, has to return now") - for task_id in task_ids: - return_masks[task_id] = torch.empty(0) # return mask of timeout task is also an empty tensor - nvtx.mark(f"wait task_ids: {task_ids} timeout") - return return_masks - time.sleep(0.0001) - nvtx.mark(f"wait task_ids: {task_ids} done") - return return_masks + if self.server_client_mode: + self.dp_client.cancel_tasks(task_ids) + else: + self.kv_task_engine.cancel_tasks(task_ids) - # the try_wait api is used for server-client mode: - # server process running the kvmanager should NOT be blocked by any single client - def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, torch.Tensor]: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="try_wait", - task_ids=task_ids, - ) - return_masks: Dict[int, torch.Tensor] = {} + def wait(self, + task_ids: Union[int, List[int]], + 
timeout: float = 20.0, + completely: bool = False) -> Dict[int, KVResponse]: if isinstance(task_ids, int): task_ids = [task_ids] - for task_id in task_ids: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return_masks[task_id] = torch.empty(0) #if not found in tracker, the return mask is an empty tensor - continue - task_tracker = self.requests_tracker[task_id] - if len(task_tracker.task_end_ops_ids) == 0: - mask = None - elif all(task_tracker.task_end_ops_status): - mask = task_tracker.return_mask - return_masks[task_id] = mask - else: - mask = None - - return return_masks - - def wait_at_layer_group(self, task_id: int, layer_group_id: int, timeout: float = 20.0) -> torch.Tensor: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="wait_at_layer_group", - task_ids=task_id, - layer_group_id=layer_group_id - ) - nvtx.mark(f"wait task_id: {task_id}, layer_group_id: {layer_group_id}") - start_time = time.time() - while True: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return torch.empty(0) #if not found in tracker, the return mask is an empty tensor - task_tracker = self.requests_tracker[task_id] - if len(task_tracker.task_end_ops_ids) == 0: - continue - if task_tracker.task_end_ops_status[layer_group_id]: - return task_tracker.return_mask - if time.time() - start_time > timeout: - flexkv_logger.warning(f"wait task_id: {task_id}, layer_group_id: {layer_group_id} " - f"timeout, has to return now") - return torch.empty(0) # return mask of timeout task is an empty tensor - time.sleep(0.0001) - - # nvtx.mark(f"wait_at_layer_group task_id: {task_id}, layer_group_id: {layer_group_id} done") - # return return_mask + if self.server_client_mode: + return self.dp_client.wait(task_ids, timeout, completely) + else: + return self.kv_task_engine.wait(task_ids, timeout, completely) - def try_wait_at_layer_group(self, - task_ids: Union[int, List[int]], - layer_group_id: int) -> Dict[int, torch.Tensor]: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="try_wait_at_layer_group", - task_ids=task_ids, - layer_group_id=layer_group_id, - ) - return_masks: Dict[int, torch.Tensor] = {} + def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, KVResponse]: if isinstance(task_ids, int): task_ids = [task_ids] - for task_id in task_ids: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return_masks[task_id] = torch.empty(0) #if not found in tracker, the return mask is an empty tensor - continue - task_tracker = self.requests_tracker[task_id] - if len(task_tracker.task_end_ops_ids) == 0: - mask = torch.empty(0) - elif task_tracker.task_end_ops_status[layer_group_id]: - mask = task_tracker.return_mask - else: - mask = torch.empty(0) - return_masks[task_id] = mask - return return_masks - - def _verify_Model_Cache_config(self, - model_config: ModelConfig, - cache_config: CacheConfig): - if cache_config.enable_remote: - if cache_config.remote_cache_path is None: - - if cache_config.remote_file_prefix is None: - raise ValueError("remote_file_prefix must be provided when remote_cache_path is None") - - if cache_config.remote_file_num is None or cache_config.remote_file_num <= 0: - raise ValueError("remote_file_num must be a positive integer") - - cache_config.remote_cache_path = [ - f"{cache_config.remote_file_prefix}_{i}" - for i in range(cache_config.remote_file_num) - ] - 
- if cache_config.remote_cache_size_mode == "block_num": - if cache_config.num_remote_blocks is None: - raise ValueError("num_remote_blocks must not None if use block_num model") - elif cache_config.remote_cache_size_mode == "file_size": - if cache_config.remote_file_size is None: - raise ValueError("remote_file_size must not None if use file_size model") - if model_config.use_mla: - kv_size = ( - model_config.num_layers - * cache_config.tokens_per_block - * model_config.num_kv_heads - * model_config.head_size - * model_config.dtype.itemsize - ) - else: - kv_size = ( - model_config.num_layers - * 2 - * cache_config.tokens_per_block - * model_config.num_kv_heads - * model_config.head_size - * model_config.dtype.itemsize - ) - cache_config.num_remote_blocks = cache_config.remote_file_size // kv_size * cache_config.remote_file_num + if self.server_client_mode: + return self.dp_client.try_wait(task_ids) + else: + return self.kv_task_engine.try_wait(task_ids) - else: - raise ValueError("remote_cache_size_mode must block_num or file_size model") + # Only for testing + def _clear_cpu_cache(self) -> None: + if self.server_client_mode: + flexkv_logger.error("clear_cache is not supported in server client mode") + return + else: + self.kv_task_engine._clear_cpu_cache() diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py new file mode 100644 index 0000000000..3d661a66bf --- /dev/null +++ b/flexkv/kvtask.py @@ -0,0 +1,518 @@ +import time +from typing import Dict, Optional, List, Union, Tuple +import threading +from enum import Enum +from dataclasses import dataclass +from typing import Callable + + +from expiring_dict import ExpiringDict +import nvtx +import torch +import numpy as np + +from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.debug import flexkv_logger +from flexkv.common.transfer import TransferOpGraph, get_nvtx_default_color +from flexkv.common.tracer import FlexKVTracer +from flexkv.cache.cache_engine import GlobalCacheEngine +from flexkv.transfer_manager import TransferManagerHandle +from flexkv.common.request import KVResponseStatus, KVResponse + +class TaskStatus(Enum): + # slot mapping is not ready + UNREADY = "unready" + # waiting for the task to be launched + READY = "ready" + # in transfer + RUNNING = "running" + # transfer completed + COMPLETED = "completed" + # transfer cancelled + CANCELLED = "cancelled" + # transfer failed + FAILED = "failed" + +class TaskType(Enum): + GET = "get" + PUT = "put" + +@dataclass +class KVTask: + # task descriptor + task_id: int + task_type: TaskType + task_end_op_id: int + task_end_op_finished: bool + status: TaskStatus + + # params + token_ids: np.ndarray + slot_mapping: np.ndarray + token_mask: Optional[np.ndarray] + dp_id: int + + # cache engine return + graph: TransferOpGraph + return_mask: np.ndarray + callback: Optional[Callable] + + def is_completed(self) -> bool: + return self.status in [TaskStatus.COMPLETED, TaskStatus.CANCELLED, TaskStatus.FAILED] + +TASK_STATUS_TO_RESPONSE_STATUS = { + TaskStatus.COMPLETED: KVResponseStatus.SUCCESS, + TaskStatus.CANCELLED: KVResponseStatus.CANCELLED, + TaskStatus.FAILED: KVResponseStatus.FAILED, + TaskStatus.RUNNING: KVResponseStatus.SUCCESS, # for early return: still running, but success +} + +def convert_to_response_status(task_status: TaskStatus) -> KVResponseStatus: + return TASK_STATUS_TO_RESPONSE_STATUS[task_status] + +class KVTaskManager: + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: Optional[str] = None, + 
use_separate_process: bool = True, + ): + if not cache_config.enable_cpu: + raise ValueError("enable_cpu must be True") + if cache_config.enable_remote and not cache_config.enable_ssd: + raise ValueError("enable_ssd must be True if enable_remote is True") + if not cache_config.enable_cpu and not cache_config.use_gds: + raise ValueError("use_gds must be True if enable_cpu is False") + self.cache_config = cache_config + self.model_config = model_config + self._check_config(model_config, cache_config) + + self.cache_engine = GlobalCacheEngine(cache_config, model_config) + + self.transfer_handle = TransferManagerHandle( + self.model_config, + self.cache_config, + use_separate_process=use_separate_process, + gpu_register_port=gpu_register_port + ) + + self.tasks: ExpiringDict[int, KVTask] = ExpiringDict(max_age_seconds=1800, max_len=100000) # 30 minutes + self.graph_to_task: Dict[int, int] = {} + + self.task_id_counter = 0 + self.task_id_lock = threading.Lock() + + self.running_tasks: int = 0 + + def start(self) -> None: + self.transfer_handle.start() + + def is_ready(self) -> bool: + return self.transfer_handle.is_ready() + + def __del__(self) -> None: + self.shutdown() + + def shutdown(self) -> None: + self.transfer_handle.shutdown() + + def create_get_task(self, + task_id: int, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray] = None, + layer_granularity: int = -1, + dp_id: int = 0, + is_fake_slot_mapping: bool = False, + ) -> None: + if task_id in self.tasks: + raise ValueError(f"Task ID {task_id} already exists") + graph, return_mask, callback, task_end_op_id = self.cache_engine.get(task_id, + token_ids, + token_mask, + slot_mapping, + self.model_config.num_layers, + layer_granularity, + dp_id) + self.tasks[task_id] = KVTask( + task_id=task_id, + task_type=TaskType.GET, + task_end_op_id=task_end_op_id, + task_end_op_finished=False, + status=TaskStatus.UNREADY if is_fake_slot_mapping else TaskStatus.READY, + token_ids=token_ids, + slot_mapping=slot_mapping, + token_mask=token_mask, + dp_id=dp_id, + graph=graph, + return_mask=return_mask, + callback=callback) + + self.graph_to_task[graph.graph_id] = task_id + + def create_put_task(self, + task_id: int, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray] = None, + dp_id: int = 0, + is_fake_slot_mapping: bool = False, + ) -> None: + if task_id in self.tasks: + raise ValueError(f"Task ID {task_id} already exists") + graph, return_mask, callback, task_end_op_id = self.cache_engine.put(task_id, + token_ids, + token_mask, + slot_mapping, + self.model_config.num_layers, + dp_id) + self.tasks[task_id] = KVTask( + task_id=task_id, + task_type=TaskType.PUT, + task_end_op_id=task_end_op_id, + task_end_op_finished=False, + status=TaskStatus.UNREADY if is_fake_slot_mapping else TaskStatus.READY, + token_ids=token_ids, + slot_mapping=slot_mapping, + token_mask=token_mask, + dp_id=dp_id, + graph=graph, + return_mask=return_mask, + callback=callback) + self.graph_to_task[graph.graph_id] = task_id + + def _launch_task(self, task_id: int) -> None: + task = self.tasks[task_id] + if task.is_completed(): + return + if task.status != TaskStatus.READY: + raise ValueError(f"Task {task_id} status is {task.status}, cannot launch") + transfer_graph = task.graph + task.status = TaskStatus.RUNNING + nvtx.mark(f"launch task: task_id={task_id}, graph_id={transfer_graph.graph_id}") + if transfer_graph.num_ops > 0: + self.transfer_handle.submit(transfer_graph) + + def _update_tasks(self, timeout: 
float = 0.001) -> None: + completed_ops = self._get_completed_ops(timeout) + for completed_graph_id, completed_op_id in completed_ops: + if completed_graph_id not in self.graph_to_task: + continue + task_id = self.graph_to_task[completed_graph_id] + task = self.tasks[task_id] + if completed_op_id == -1: # the graph is totally finished + self._mark_completed(task_id) + elif completed_op_id == task.task_end_op_id: + self.tasks[task_id].task_end_op_finished = True + + def _cancel_task(self, task_id: int) -> None: + task = self.tasks[task_id] + if task.is_completed(): + flexkv_logger.warning(f"Task {task_id} is already completed, cannot cancel") + return + if task.status == TaskStatus.RUNNING: + flexkv_logger.warning(f"Task {task_id} is running, cannot cancel") + return + if task.status == TaskStatus.CANCELLED: + flexkv_logger.warning(f"Task {task_id} is already cancelled, cannot cancel") + return + task.status = TaskStatus.CANCELLED + self.graph_to_task.pop(task.graph.graph_id, None) + + def check_completed(self, task_id: int, completely: bool = False) -> bool: + self._process_empty_graph(task_id) + task = self.tasks[task_id] + if completely: + return task.is_completed() + return task.is_completed() or task.task_end_op_finished + + def set_slot_mappings(self, + task_ids: List[int], + slot_mappings: List[np.ndarray]) -> None: + for task_id, slot_mapping in zip(task_ids, slot_mappings): + self._set_slot_mapping_impl(task_id, slot_mapping) + + def _set_slot_mapping_impl(self, task_id: int, slot_mapping: np.ndarray) -> None: + task = self.tasks[task_id] + if task.status != TaskStatus.UNREADY: + return + graph_ids = self.cache_engine.slot_mapping_to_block_ids(slot_mapping, + self.cache_config.tokens_per_block) + task.graph.set_gpu_blocks(graph_ids) + task.status = TaskStatus.READY + + def _gen_task_id(self) -> int: + with self.task_id_lock: + old_value = self.task_id_counter + self.task_id_counter += 1 + return old_value + + def _mark_completed(self, task_id: int) -> None: + task = self.tasks[task_id] + if task.is_completed(): + return + if task.callback: + task.callback() + task.status = TaskStatus.COMPLETED + task.task_end_op_finished = True + self.graph_to_task.pop(task.graph.graph_id) + + def _process_empty_graph(self, task_id: int) -> None: + task = self.tasks[task_id] + if task.graph.num_ops == 0: + self._mark_completed(task_id) + + def _get_completed_ops(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + return self.transfer_handle.wait(timeout) + + def _check_config(self, model_config: ModelConfig, cache_config: CacheConfig) -> None: + if cache_config.enable_remote: + if cache_config.remote_cache_path is None: + + if cache_config.remote_file_prefix is None: + raise ValueError("remote_file_prefix must be provided when remote_cache_path is None") + + if cache_config.remote_file_num is None or cache_config.remote_file_num <= 0: + raise ValueError("remote_file_num must be a positive integer") + + cache_config.remote_cache_path = [ + f"{cache_config.remote_file_prefix}_{i}" + for i in range(cache_config.remote_file_num) + ] + + if cache_config.remote_cache_size_mode == "block_num": + if cache_config.num_remote_blocks is None: + raise ValueError("num_remote_blocks must not None if use block_num model") + elif cache_config.remote_cache_size_mode == "file_size": + if cache_config.remote_file_size is None: + raise ValueError("remote_file_size must not None if use file_size model") + if model_config.use_mla: + kv_size = ( + model_config.num_layers + * cache_config.tokens_per_block + 
* model_config.num_kv_heads + * model_config.head_size + * model_config.dtype.itemsize + ) + else: + kv_size = ( + model_config.num_layers + * 2 + * cache_config.tokens_per_block + * model_config.num_kv_heads + * model_config.head_size + * model_config.dtype.itemsize + ) + cache_config.num_remote_blocks = cache_config.remote_file_size // kv_size * cache_config.remote_file_num + + else: + raise ValueError("remote_cache_size_mode must block_num or file_size model") + + +class KVTaskEngine(KVTaskManager): + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: Optional[str] = None, + use_separate_process: bool = True, + ): + super().__init__(model_config, cache_config, gpu_register_port, use_separate_process) + self.tracer = FlexKVTracer(cache_config) + + def get_async(self, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray] = None, + layer_granularity: int = -1, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + task_id, return_mask = self._get_match_impl(token_ids, + slot_mapping, + is_fake_slot_mapping=False, + token_mask=token_mask, + layer_granularity=layer_granularity, + dp_id=dp_id, + task_id=task_id) + self._launch_task(task_id) + return task_id, return_mask + + def put_async(self, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray] = None, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + task_id, return_mask = self._put_match_impl(token_ids, + slot_mapping, + is_fake_slot_mapping=False, + token_mask=token_mask, + dp_id=dp_id, + task_id=task_id) + self._launch_task(task_id) + return task_id, return_mask + + def _wait_impl(self, + task_ids: List[int], + timeout: float = 20.0, + completely: bool = False, + only_return_finished: bool = False, + ) -> Dict[int, KVResponse]: + return_responses = {} + start_time = time.time() + is_timeout = timeout == 0.0 + + self._update_tasks(timeout=0) + + for task_id in task_ids: + while True: + if task_id not in self.tasks: + flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") + return_responses[task_id] = KVResponse( + status=KVResponseStatus.NOTFOUND, + task_id=task_id, + return_mask=None + ) + break + elif self.tasks[task_id].status == TaskStatus.UNREADY: + flexkv_logger.warning(f"task_id {task_id} is unready") + return_responses[task_id] = KVResponse( + status=KVResponseStatus.UNREADY, + task_id=task_id, + return_mask=None + ) + break + elif self.check_completed(task_id, completely=completely): + return_responses[task_id] = KVResponse( + status=convert_to_response_status(self.tasks[task_id].status), + task_id=task_id, + return_mask=self.tasks[task_id].return_mask + ) + break + elif only_return_finished: + break + elif time.time() - start_time > timeout: + is_timeout = True + if is_timeout: + return_responses[task_id] = KVResponse( + status=KVResponseStatus.TIMEOUT, + task_id=task_id, + return_mask=None + ) + break + self._update_tasks(timeout=0.001) + return return_responses + + def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, KVResponse]: + if isinstance(task_ids, int): + task_ids = [task_ids] + nvtx.mark(f"try_wait task_ids: {task_ids}") + return_responses = self._wait_impl(task_ids, + completely=False, + only_return_finished=True) + return return_responses + + def wait(self, + task_ids: Union[int, List[int]], + timeout: float = 20.0, + completely: bool = False) -> Dict[int, KVResponse]: + if isinstance(task_ids, int): + task_ids = [task_ids] + 
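# Blocks until each task's key transfer op has finished (or the whole transfer graph when completely=True); tasks that miss `timeout` are returned with KVResponseStatus.TIMEOUT. +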
nvtx.push_range(f"wait task_ids: {task_ids}", color=get_nvtx_default_color()) + return_responses = self._wait_impl(task_ids, timeout, completely=completely) + nvtx.pop_range() + return return_responses + + def get_match(self, + token_ids: np.ndarray, + token_mask: Optional[np.ndarray] = None, + layer_granularity: int = -1, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + if token_mask is None: + token_mask = np.ones_like(token_ids, dtype=bool) + fake_slot_mapping = np.zeros_like(token_ids[token_mask]) + return self._get_match_impl(token_ids, + fake_slot_mapping, + is_fake_slot_mapping=True, + token_mask=token_mask, + layer_granularity=layer_granularity, + dp_id=dp_id, + task_id=task_id) + + def _get_match_impl(self, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + is_fake_slot_mapping: bool = False, + token_mask: Optional[np.ndarray] = None, + layer_granularity: int = -1, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + if token_mask is None: + token_mask = np.ones_like(token_ids) + if layer_granularity == -1: + layer_granularity = self.model_config.num_layers + if task_id == -1: + task_id = self._gen_task_id() + nvtx.push_range(f"get match: task_id={task_id}", color=get_nvtx_default_color()) + self.create_get_task(task_id, + token_ids, + slot_mapping, + token_mask, + layer_granularity, + dp_id, + is_fake_slot_mapping=is_fake_slot_mapping) + self._process_empty_graph(task_id) + nvtx.pop_range() + return task_id, self.tasks[task_id].return_mask + + def put_match(self, + token_ids: np.ndarray, + token_mask: Optional[np.ndarray] = None, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + fake_slot_mapping = np.zeros_like(token_ids) + return self._put_match_impl(token_ids, + fake_slot_mapping, + is_fake_slot_mapping=True, + token_mask=token_mask, + dp_id=dp_id, + task_id=task_id) + + def _put_match_impl(self, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + is_fake_slot_mapping: bool = False, + token_mask: Optional[np.ndarray] = None, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + if token_mask is None: + token_mask = np.ones_like(token_ids) + if task_id == -1: + task_id = self._gen_task_id() + nvtx.push_range(f"put match: task_id={task_id}", color=get_nvtx_default_color()) + self.create_put_task(task_id, + token_ids, + slot_mapping, + token_mask, + dp_id, + is_fake_slot_mapping=is_fake_slot_mapping) + self._process_empty_graph(task_id) + nvtx.pop_range() + return task_id, self.tasks[task_id].return_mask + + def launch_tasks(self, + task_ids: List[int], + slot_mappings: List[np.ndarray]) -> None: + assert isinstance(slot_mappings[0], np.ndarray) + self.set_slot_mappings(task_ids, slot_mappings) + for task_id in task_ids: + self._launch_task(task_id) + + def cancel_tasks(self, task_ids: Union[int, List[int]]) -> None: + if isinstance(task_ids, int): + task_ids = [task_ids] + for task_id in task_ids: + self._cancel_task(task_id) + + def _clear_cpu_cache(self) -> None: + self.cache_engine.cpu_cache_engine.reset() diff --git a/flexkv/server/client.py b/flexkv/server/client.py index 257c3f2fa3..1643af98a3 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -2,16 +2,18 @@ from multiprocessing import Lock, Queue from multiprocessing.connection import Connection from queue import Queue as ThreadQueue -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Callable import tempfile import torch import zmq +import numpy as np -from flexkv.common.config import 
ModelConfig +from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout +from flexkv.common.request import KVResponseStatus, KVResponse from flexkv.server.utils import get_zmq_socket from flexkv.server.request import ( RegisterDPClientRequest, @@ -19,30 +21,36 @@ IsReadyRequest, PutRequest, GetRequest, + PutMatchRequest, + GetMatchRequest, + LaunchTaskRequest, + CancelTaskRequest, WaitRequest, TryWaitRequest, CheckRunningRequest, + StartRequest, ShutdownRequest, Response ) - class KVDPClient: def __init__( self, server_recv_port: str, model_config: ModelConfig, + dp_client_id: int, ): # Init inter-process communication context = zmq.Context(2) self.send_to_server = get_zmq_socket( context, zmq.SocketType.PUSH, server_recv_port, False ) - client_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=True).name}" + self.client_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=True).name}" self.recv_from_server = get_zmq_socket( - context, zmq.SocketType.PULL, client_recv_port, True + context, zmq.SocketType.PULL, self.client_recv_port, True ) - self.dp_client_id = self.register_to_server(model_config, client_recv_port) + self.dp_client_id = dp_client_id + self.model_config = model_config self._task_id_range = (self.dp_client_id * 10000000, (self.dp_client_id + 1) * 10000000) self._task_id_counter = self._task_id_range[0] @@ -57,22 +65,20 @@ def _get_task_id(self) -> int: self._task_id_counter = self._task_id_range[0] return old_value + def start_server_and_register(self) -> None: + #start server and register + req = StartRequest(self.dp_client_id) + self.send_to_server.send_pyobj(req) + self.register_to_server(self.model_config, self.client_recv_port) + def register_to_server( self, model_config: ModelConfig, client_recv_port: str, - ) -> int: - register_req = RegisterDPClientRequest(model_config, client_recv_port) - + ) -> None: + register_req = RegisterDPClientRequest(self.dp_client_id, model_config, client_recv_port) self.send_to_server.send_pyobj(register_req) - # blocking - response: Response = self.recv_from_server.recv_pyobj() - if response.success: - flexkv_logger.info(f"DP client registered successfully! DP client id: {response.dp_client_id}") - return response.dp_client_id - else: - flexkv_logger.error(f"DP client registeration fialed: {response.error_msg}") - raise + flexkv_logger.info(f"DP client {self.dp_client_id} registered to server request sent!") def is_ready( self, @@ -80,61 +86,103 @@ def is_ready( req = IsReadyRequest(self.dp_client_id) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() - if response.success: - return response.is_ready - else: - flexkv_logger.error(f"is_ready failed: {response.error_msg}") - raise - + return response.is_ready + def put_async( self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor], - ) -> Optional[int]: - # start_time = time.time() + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray], + ) -> int: req = PutRequest(self.dp_client_id, - token_ids.numpy(), - slot_mapping.numpy(), - token_mask.numpy() if token_mask is not None else None, + token_ids, + slot_mapping, + token_mask if token_mask is not None else None, self._get_task_id()) self.send_to_server.send_pyobj(req) - # end_time = time.time() - # print(f"[dpclient] put_async task: {req.task_id} created. 
time: {(end_time - start_time)*1000:.2f}ms") return req.task_id + def put_match( + self, + token_ids: np.ndarray, + token_mask: Optional[np.ndarray], + ) -> Optional[Tuple[int, np.ndarray]]: + req = PutMatchRequest(self.dp_client_id, + token_ids, + token_mask if token_mask is not None else None, + self._get_task_id()) + self.send_to_server.send_pyobj(req) + response: Response = self.recv_from_server.recv_pyobj() + if response.error_msg is None: + return response.task_id, response.mask + else: + flexkv_logger.error(f"put_match failed, error_msg: {response.error_msg}") + return None + def get_async( self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor], - ) -> Optional[int]: - # start_time = time.time() + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray], + layer_granularity: int, + ) -> int: req = GetRequest(self.dp_client_id, - token_ids.numpy(), - slot_mapping.numpy(), - token_mask.numpy() if token_mask is not None else None, - self._get_task_id()) - + token_ids, + slot_mapping, + token_mask if token_mask is not None else None, + self._get_task_id(), + layer_granularity) self.send_to_server.send_pyobj(req) - # end_time = time.time() - # print(f"[dpclient] get_async task: {req.task_id} created. time: {(end_time - start_time)*1000:.2f}ms") return req.task_id + def get_match( + self, + token_ids: np.ndarray, + token_mask: Optional[np.ndarray], + layer_granularity: int, + ) -> Optional[Tuple[int, np.ndarray]]: + req = GetMatchRequest(self.dp_client_id, + token_ids, + token_mask if token_mask is not None else None, + layer_granularity, + self._get_task_id()) + self.send_to_server.send_pyobj(req) + response: Response = self.recv_from_server.recv_pyobj() + if response.error_msg is None: + return req.task_id, response.mask + else: + flexkv_logger.error(f"get_match failed, error_msg: {response.error_msg}") + return None + + def launch_tasks( + self, + task_ids: List[int], + slot_mappings: List[np.ndarray], + ) -> None: + req = LaunchTaskRequest(self.dp_client_id, task_ids, slot_mappings) + self.send_to_server.send_pyobj(req) + + def cancel_task( + self, + task_ids: List[int], + ) -> None: + req = CancelTaskRequest(self.dp_client_id, task_ids) + self.send_to_server.send_pyobj(req) + def wait( self, wait_task_ids: List[int], wait_timeout: float = 20.0, - ) -> Optional[Dict[int, torch.Tensor]]: - req = WaitRequest(self.dp_client_id, None, wait_task_ids, wait_timeout) - + completely: bool = False, + ) -> Optional[Dict[int, KVResponse]]: + req = WaitRequest(self.dp_client_id, None, wait_task_ids, wait_timeout, completely) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() - if response.masks is not None: - response.masks = {k: torch.from_numpy(v) for k, v in response.masks.items()} - if response.success: - # flexkv_logger.info(f"wait tasks: {wait_task_ids} finished.") - return response.masks + if response.status is not None: + for k, v in response.status.items(): + if v.status != KVResponseStatus.SUCCESS: + flexkv_logger.error(f"wait task {k} failed: {v.status}") + return response.status else: flexkv_logger.error(f"wait tasks: {wait_task_ids} in DP {self.dp_client_id} failed.") return None @@ -142,35 +190,23 @@ def wait( def try_wait( self, try_wait_task_ids: List[int], - ) -> Optional[Dict[int, torch.Tensor]]: + ) -> Optional[Dict[int, KVResponse]]: req = TryWaitRequest(self.dp_client_id, None, try_wait_task_ids) self.send_to_server.send_pyobj(req) response: Response = 
self.recv_from_server.recv_pyobj() - if response.masks is not None: - response.masks = {k: torch.from_numpy(v) for k, v in response.masks.items()} - if response.success: - # flexkv_logger.info(f"try_wait tasks: {try_wait_task_ids} finished.") - return response.masks + if response.status is not None: + for k, v in response.status.items(): + if v.status != KVResponseStatus.SUCCESS: + flexkv_logger.error(f"try_wait task {k} failed: {v.status}") + return response.status else: flexkv_logger.error(f"try_wait tasks: {try_wait_task_ids} in DP {self.dp_client_id} failed.") return None - def check_running(self) -> bool: - req = CheckRunningRequest(self.dp_client_id) - self.send_to_server.send_pyobj(req) - response: Response = self.recv_from_server.recv_pyobj() - return response.running - def shutdown(self) -> None: req = ShutdownRequest(self.dp_client_id) self.send_to_server.send_pyobj(req) - response: Response = self.recv_from_server.recv_pyobj() - if response.success: - flexkv_logger.info(f"DP client {self.dp_client_id} shutdown successfully.") - else: - flexkv_logger.error(f"DP client {self.dp_client_id} shutdown failed.") - raise class KVTPClient: def __init__( @@ -178,23 +214,17 @@ def __init__( server_recv_port: str, dp_client_id: int, device_id: int, - tp_rank: int, ): # Init inter-process communication context = zmq.Context(2) self.send_to_server = get_zmq_socket( context, zmq.SocketType.PUSH, server_recv_port, False ) - self.client_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" - self.recv_from_server = get_zmq_socket( - context, zmq.SocketType.PULL, self.client_recv_port, True - ) self.dp_client_id = dp_client_id self.device_id = device_id - self.tp_rank = tp_rank - flexkv_logger.info(f"KVTPClient {tp_rank} of KVDPClient {self.dp_client_id} Initialized!") + flexkv_logger.info(f"KVTPClient {device_id} of KVDPClient {self.dp_client_id} Initialized!") def register_to_server( self, @@ -214,24 +244,12 @@ def register_to_server( register_req = RegisterTPClientRequest( self.dp_client_id, - self.tp_rank, self.device_id, - self.client_recv_port, handles, kv_layout ) - self.send_to_server.send_pyobj(register_req) - # blocking - response: Response = self.recv_from_server.recv_pyobj() - if response.success: - flexkv_logger.info(f"TP client of DP client {self.dp_client_id} registered successfully!") - else: - flexkv_logger.error( - f"TP client of DP client {self.dp_client_id} registeration fialed: {response.error_msg}" - ) - raise - + self.send_to_server.send_pyobj(register_req, flags=zmq.NOBLOCK) if __name__ == "__main__": diff --git a/flexkv/server/request.py b/flexkv/server/request.py index 532feae047..f6df970b17 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -2,15 +2,16 @@ from typing import Dict, List, Optional, Tuple import numpy as np -import torch from flexkv.common.config import ModelConfig from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout +from flexkv.common.request import KVResponseStatus @dataclass class RegisterDPClientRequest: + dp_client_id: int model_config: ModelConfig client_recv_port: str @@ -18,9 +19,7 @@ class RegisterDPClientRequest: @dataclass class RegisterTPClientRequest: dp_client_id: int - tp_rank: int device_id: int - client_recv_port: str handles: List[TensorSharedHandle] gpu_layout: KVCacheLayout @@ -44,6 +43,33 @@ class GetRequest: slot_mapping: np.ndarray token_mask: Optional[np.ndarray] task_id: int = -1 + layer_granularity: int = -1 + +@dataclass +class 
PutMatchRequest: + dp_client_id: int + token_ids: np.ndarray + token_mask: Optional[np.ndarray] + task_id: int = -1 + +@dataclass +class GetMatchRequest: + dp_client_id: int + token_ids: np.ndarray + token_mask: Optional[np.ndarray] + layer_granularity: int + task_id: int = -1 + +@dataclass +class LaunchTaskRequest: + dp_client_id: int + task_ids: List[int] + slot_mappings: List[np.ndarray] + +@dataclass +class CancelTaskRequest: + dp_client_id: int + task_ids: List[int] @dataclass class WaitRequest: @@ -51,6 +77,7 @@ class WaitRequest: tp_rank: Optional[int] wait_task_ids: List[int] wait_timeout: float = 20.0 + completely: bool = False # Used for async put/get @dataclass @@ -62,19 +89,25 @@ class TryWaitRequest: @dataclass class Response: - dp_client_id: int + dp_client_id: int = -1 task_id: Optional[int] = None - masks: Optional[Dict[int, np.ndarray]] = None - success: bool = True - running: bool = False - error_msg: str = "" + mask: Optional[Dict[int, np.ndarray]] = None + status: Optional[Dict[int, KVResponseStatus]] = None is_ready: bool = False + error_msg: Optional[str] = None + @property + def success(self) -> bool: + return self.status is not None and \ + all(self.status[task_id] == KVResponseStatus.SUCCESS for task_id in self.status.keys()) @dataclass -class ShutdownRequest: +class StartRequest: dp_client_id: int +@dataclass +class ShutdownRequest: + dp_client_id: int @dataclass class CheckRunningRequest: diff --git a/flexkv/server/server.py b/flexkv/server/server.py index 8ea5b91faa..1849c1e304 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -7,12 +7,15 @@ import time import threading from threading import Lock +import multiprocessing as mp +import socket +import os from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType -from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVTaskEngine from flexkv.server.utils import get_zmq_socket from flexkv.server.request import ( RegisterDPClientRequest, @@ -20,27 +23,19 @@ IsReadyRequest, PutRequest, GetRequest, + PutMatchRequest, + GetMatchRequest, + LaunchTaskRequest, + CancelTaskRequest, WaitRequest, TryWaitRequest, Response, + StartRequest, ShutdownRequest, CheckRunningRequest, ) import contextlib - -class TPClient: - def __init__( - self, - send_to_client: zmq.Socket, - tp_rank: int = 0, - device_id: int = 0, - ): - self.tp_rank = tp_rank - self.device_id = device_id - self.send_to_client = send_to_client - - class DPClient: def __init__( self, @@ -50,40 +45,11 @@ def __init__( ): self.client_id = client_id self.tp_size = tp_size - self.tp_client_dict: Dict[int, TPClient] = {} self.send_to_client = send_to_client self.is_ready: bool = False - def register_tp_client( - self, - context: zmq.Context, - client_recv_port: str, - tp_rank: int = 0, - device_id: int = 0, - ) -> None: - if tp_rank in self.tp_client_dict: - flexkv_logger.error(f"TP rank: {tp_rank} in DP client: {self.client_id} has already registered.") - raise - if tp_rank >= self.tp_size: - flexkv_logger.error(f"TP rank: {tp_rank} is larger than TP size of DP client: {self.client_id}.") - raise - - send_to_client = get_zmq_socket( - context, zmq.SocketType.PUSH, client_recv_port, False - ) - - self.tp_client_dict[tp_rank] = TPClient(send_to_client, tp_rank, device_id) - - flexkv_logger.info(f"TP rank: {tp_rank} in DP client: {self.client_id} registered successfully.") - 
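# --- Editor's illustrative sketch (not part of the patch) ---------------------
# The new request types above (GetMatchRequest / LaunchTaskRequest / WaitRequest with
# `completely`) back a two-phase flow: first "match" against the cache index with no
# GPU slots committed, then "launch" once the scheduler has allocated blocks and has a
# slot_mapping. A rough usage sketch against KVTaskEngine, assuming `engine` is an
# already-started KVTaskEngine; token_ids and slot_mapping are placeholder inputs.
import numpy as np

def load_prefix_if_cached(engine, token_ids: np.ndarray, slot_mapping: np.ndarray):
    # Phase 1: query the index only; returns a task id and a per-token hit mask.
    task_id, hit_mask = engine.get_match(token_ids)
    if not hit_mask.any():
        # Nothing matched: the transfer graph is empty and the task is already
        # marked completed by _process_empty_graph, so there is nothing to launch.
        return None
    # Phase 2: provide the GPU slot mapping and start the actual transfers.
    engine.launch_tasks([task_id], [slot_mapping])
    # Phase 3: block until the whole graph finishes (completely=True) or times out.
    return engine.wait([task_id], timeout=20.0, completely=True)[task_id]
# --- end of editor's sketch ----------------------------------------------------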
- if len(self.tp_client_dict) == self.tp_size: - self.is_ready = True - flexkv_logger.info(f"All the TP clients in DP client: {self.client_id} has registered. " - f"DP client: {self.client_id} is ready!") - - class ClientManager: def __init__( self, @@ -116,20 +82,6 @@ def register_dp_client( return client_id - def register_tp_client( - self, - context: zmq.Context, - dp_client_id: int, - client_recv_port: str, - tp_rank: int, - device_id: int - ) -> None: - if dp_client_id not in self.client_dict: - flexkv_logger.error(f"DP client: {dp_client_id} has not registered.") - raise - self.client_dict[dp_client_id].register_tp_client( - context, client_recv_port, tp_rank, device_id) - def delete_dp_client(self, client_id: int) -> None: if client_id not in self.client_dict: flexkv_logger.error(f"DP client: {client_id} dosen't exist. Delete failed.") @@ -150,166 +102,127 @@ def is_dp_client_ready(self, dp_client_id: int) -> bool: return self.client_dict[dp_client_id].is_ready return False +class KVServerHandle: + def __init__(self, process: mp.Process): + self.process = process + + def shutdown(self) -> None: + self.process.join(timeout=5) + if self.process.is_alive(): + flexkv_logger.info("force terminate the server process") + self.process.terminate() + self.process.join() + + def __del__(self) -> None: + if self.process.is_alive(): + self.shutdown() class KVServer: def __init__( self, model_config: ModelConfig, cache_config: CacheConfig, - server_recv_port: Optional[str] = None, + gpu_register_port: str, + server_recv_port: str ): # Init inter-process communication self.context = zmq.Context(2) - if server_recv_port is None: - server_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" self.recv_from_client = get_zmq_socket( self.context, zmq.SocketType.PULL, server_recv_port, True) self.client_manager = ClientManager(max_num_dp_client=model_config.dp_size) - self.kvmanager = KVManager(model_config, cache_config) - - if self.kvmanager.is_ready(): - flexkv_logger.info("KVManager is ready, starting with worker initialization...") - self.kvmanager.start() + self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port, False) self.req_counter = 0 + self._is_ready = False + self._running = False - flexkv_logger.info(f"Server Initialized! 
[Recv Port]: {server_recv_port}") - # self._running = True - + # Request handler dispatch table + self.request_handlers = { + StartRequest: self._handle_start_request, + RegisterDPClientRequest: self._handle_register_dp_client_request, + IsReadyRequest: self._handle_is_ready_request, + GetRequest: self._handle_get_request, + PutRequest: self._handle_put_request, + GetMatchRequest: self._handle_get_match_request, + PutMatchRequest: self._handle_put_match_request, + WaitRequest: self._handle_wait_request, + LaunchTaskRequest: self._handle_launch_task_request, + CancelTaskRequest: self._handle_cancel_task_request, + TryWaitRequest: self._handle_try_wait_request, + ShutdownRequest: self._handle_shutdown_request, + } + + def is_ready(self) -> bool: + return self._is_ready + + def start_server(self) -> None: + self.kv_task_engine.start() + self._is_ready = True + + @staticmethod + def _server_process(model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str, + server_recv_port: str) -> None: + + server = KVServer(model_config, cache_config, gpu_register_port, server_recv_port) + server.run() + + @classmethod + def create_server(cls, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str, + server_recv_port: Optional[str] = None) -> 'KVServerHandle': + #if server_recv_port is None: + # server_recv_port = f"ipc:///tmp/flexkv_srv_{uuid.uuid4().hex[:8]}" #TODO unify this + + # Set spawn method for CUDA compatibility + with contextlib.suppress(RuntimeError): + mp.set_start_method("spawn") + process = mp.Process(target=cls._server_process, + args=(model_config, cache_config, gpu_register_port, server_recv_port)) + process.start() + flexkv_logger.info(f"KVServer process started, PID: {process.pid}") + + return KVServerHandle(process) def run(self) -> None: """Main server loop""" # TODO: handle error and return error response # TODO: support check finish + flexkv_logger.info("Servering waiting to be started") + req = self.recv_from_client.recv_pyobj() + if isinstance(req, StartRequest): + flexkv_logger.info(f"Received start request from DP client {req.dp_client_id}, " + f"Starting server...") + self.start_server() + else: + raise TypeError(f"Received RequestType: {type(req)} from DP client " + f"{req.dp_client_id} before the start request") self._running = True while self._running: try: - flexkv_logger.info("start wait for req") + flexkv_logger.info("start waiting for req") req = self.recv_from_client.recv_pyobj() flexkv_logger.info(f"recv req: {type(req)}") - # register dp client - if isinstance(req, RegisterDPClientRequest): - self._verify_model_config(req.model_config) - client_id = self.client_manager.register_dp_client( - self.context, - req.client_recv_port, - req.model_config.tp_size - ) - response = Response(client_id) - result_zmq = self.client_manager.get_zmq(client_id) - result_zmq.send_pyobj(response) - - - elif isinstance(req, RegisterTPClientRequest): - self.client_manager.register_tp_client( - self.context, - req.dp_client_id, - req.client_recv_port, - req.tp_rank, - req.device_id, - ) - - # register GPU Memory - self.kvmanager.register_single_gpu_blocks(req.handles, - req.gpu_layout, - req.dp_client_id, - req.tp_rank) - - response = Response(req.dp_client_id) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id, req.tp_rank) - result_zmq.send_pyobj(response) - - if self.kvmanager.is_ready(): - flexkv_logger.info("All TP clients registered, starting KVManager...") - self.kvmanager.start() - - elif isinstance(req, 
IsReadyRequest): - is_ready = self.kvmanager.is_ready() - response = Response(req.dp_client_id, is_ready=is_ready) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, GetRequest): - assert self.client_manager.is_dp_client_ready(req.dp_client_id) - req_id = self.kvmanager.get_async( - token_ids=torch.from_numpy(req.token_ids), - slot_mapping=torch.from_numpy(req.slot_mapping), - token_mask=torch.from_numpy(req.token_mask) if req.token_mask is not None else None, - layer_granularity=-1, - dp_id=req.dp_client_id, - task_id=req.task_id, - ) - if req.task_id == -1: - response = Response(req.dp_client_id, req_id) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, PutRequest): - assert self.client_manager.is_dp_client_ready(req.dp_client_id) - req_id = self.kvmanager.put_async( - token_ids=torch.from_numpy(req.token_ids), - slot_mapping=torch.from_numpy(req.slot_mapping), - token_mask=torch.from_numpy(req.token_mask) if req.token_mask is not None else None, - dp_id=req.dp_client_id, - task_id=req.task_id, - ) - if req.task_id == -1: - response = Response(req.dp_client_id, req_id) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, WaitRequest): - # TODO: support TP client wait - masks = self.kvmanager.wait( - req.wait_task_ids, - timeout=req.wait_timeout, - ) - if masks is not None: - # Convert to numpy arrays for serialization - masks = {k: v.numpy() if isinstance(v, torch.Tensor) else v for k, v in masks.items()} - response = Response(req.dp_client_id, masks=masks) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, TryWaitRequest): - # TODO: support TP client try_wait - masks = self.kvmanager.try_wait( - req.try_wait_task_ids, - ) - if masks is not None: - # Convert to numpy arrays for serialization - masks = {k: v.numpy() if isinstance(v, torch.Tensor) else v for k, v in masks.items()} - response = Response(req.dp_client_id, masks=masks) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, ShutdownRequest): - flexkv_logger.info(f"Received shutdown request from DP client {req.dp_client_id}") - # Gracefully shutdown the server - self._running = False - # Send response back to client - response = Response(req.dp_client_id, success=True) - result_zmq = self.client_manager.get_zmq(req.dp_client_id) - result_zmq.send_pyobj(response) - break + # Use dispatch table for request handling + req_type = type(req) + handler = self.request_handlers.get(req_type) - elif isinstance(req, CheckRunningRequest): - response = Response(req.dp_client_id, success=True, running=self.kvmanager.is_running()) - result_zmq = self.client_manager.get_zmq(req.dp_client_id) - result_zmq.send_pyobj(response) + if handler is None: + raise TypeError(f"Unrecognized RequestType: {req_type}") - else: - raise TypeError(f"Unregonized RequestType: {type(req)}") + # Call the corresponding handler method + handler(req) + + # If the request is a shutdown request, exit the loop + if req_type == ShutdownRequest: + break except zmq.ZMQError as e: flexkv_logger.error(f"ZMQ Error: {e}", exc_info=True) @@ -329,339 +242,110 @@ def _verify_model_config( # TODO return True - def __del__(self) -> None: - self.kvmanager.shutdown() - - -class SchedulerServer: - """ - Scheduler server that merges the 
functionality of KVServer and KVDPClient. - Note that this class is ONLY FOR CASES WHEN DP_SIZE = 1. - - This class can: - 1. Directly call KVManager methods to avoid inter-process communication latency - 2. Accept registration requests from TPClient - 3. Provide the same interface as KVDPClient (put_async, get_async, wait, try_wait) - """ - - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - server_recv_port: Optional[str] = None, - ): - self.model_config = model_config - self.cache_config = cache_config - - # Initialize KVManager (similar to KVServer) - self.kvmanager = KVManager(model_config, cache_config) - - # Start KVManager if it's ready (e.g., when no TP clients are needed) - if self.kvmanager.is_ready(): - try: - self.kvmanager.start() - flexkv_logger.info("KVManager started during initialization") - except Exception as e: - flexkv_logger.warning(f"KVManager start failed during initialization: {e}") + # Request Handler Methods - # For TPClient compatibility, we need a server to receive TPClient registration requests - self.context = zmq.Context(2) - if server_recv_port is None: - server_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" - - self.server_recv_port = server_recv_port - self.recv_from_client = get_zmq_socket( - self.context, zmq.SocketType.PULL, server_recv_port, True) - - # Manage TP clients - self.tp_size = model_config.tp_size - self.tp_client_dict: Dict[int, TPClient] = {} - self.is_ready: bool = False - - # DP client related - self.dp_client_id = 0 # Fixed to 0 because we merged scheduler and server - self._task_id_range = (self.dp_client_id * 10000000, (self.dp_client_id + 1) * 10000000) - self._task_id_counter = self._task_id_range[0] - self._task_id_lock = Lock() - - # Server thread control - self._running = False - self._server_thread = None + def _handle_start_request(self, req: StartRequest) -> None: + """Handle start request""" + flexkv_logger.info(f"Received start request from DP client {req.dp_client_id}") - flexkv_logger.info(f"SchedulerServer Initialized! 
[Recv Port]: {server_recv_port}") - - def _get_task_id(self) -> int: - """Generate unique task ID""" - with self._task_id_lock: - old_value = self._task_id_counter - self._task_id_counter += 1 - if self._task_id_counter >= self._task_id_range[1]: - self._task_id_counter = self._task_id_range[0] - return old_value - - def start_server_thread(self) -> None: - """Start background server thread to handle TPClient requests""" - if self._server_thread is not None and self._server_thread.is_alive(): - flexkv_logger.warning("Server thread is already running") - return - - self._running = True - self._server_thread = threading.Thread(target=self._server_loop, daemon=True) - self._server_thread.start() - flexkv_logger.info("SchedulerServer background thread started") - - def _server_loop(self) -> None: - """Background server loop to handle requests from TPClient""" - while self._running: - try: - # Set non-blocking receive to allow checking _running status - try: - req = self.recv_from_client.recv_pyobj(zmq.NOBLOCK) - except zmq.Again: - time.sleep(0.001) # Brief sleep to avoid busy waiting - continue - - flexkv_logger.info(f"SchedulerServer received request: {type(req)}") - - if isinstance(req, RegisterTPClientRequest): - self._handle_tp_registration(req) - elif isinstance(req, ShutdownRequest): - flexkv_logger.info("Received shutdown request from TP client") - response = Response(req.dp_client_id, success=True) - # Since we don't know which TP client sent the shutdown request, - # we send response to all registered TP clients - self._running = False - for tp_client in self.tp_client_dict.values(): - tp_client.send_to_client.send_pyobj(response) - break - else: - flexkv_logger.error(f"Unrecognized RequestType in SchedulerServer: {type(req)}") - - except zmq.ZMQError as e: - if e.errno == zmq.ETERM: - break # Context terminated - flexkv_logger.error(f"ZMQ Error in SchedulerServer: {e}", exc_info=True) - except Exception as e: - flexkv_logger.error(f"Error in SchedulerServer: {e}", exc_info=True) - time.sleep(0.0001) - - flexkv_logger.info("SchedulerServer background thread stopped") - - def _handle_tp_registration(self, req: RegisterTPClientRequest) -> None: - """Handle TP Client registration request""" - tp_rank = req.tp_rank - - if tp_rank in self.tp_client_dict: - flexkv_logger.error(f"TP rank: {tp_rank} has already registered.") - response = Response(req.dp_client_id, success=False, - error_msg=f"TP rank {tp_rank} already registered") - elif tp_rank >= self.tp_size: - flexkv_logger.error(f"TP rank: {tp_rank} is larger than TP size: {self.tp_size}.") - response = Response(req.dp_client_id, success=False, - error_msg=f"TP rank {tp_rank} exceeds TP size {self.tp_size}") - else: - try: - # Create connection to TP client - send_to_client = get_zmq_socket( - self.context, zmq.SocketType.PUSH, req.client_recv_port, False - ) - - self.tp_client_dict[tp_rank] = TPClient(send_to_client, tp_rank, req.device_id) - - # Register GPU Memory to KVManager - self.kvmanager.register_single_gpu_blocks( - req.handles, - req.gpu_layout, - self.dp_client_id, # Use fixed dp_client_id = 0 - req.tp_rank - ) - - flexkv_logger.info(f"TP rank: {tp_rank} registered successfully.") - - # Check if all TP clients have registered - if len(self.tp_client_dict) == self.tp_size: - self.is_ready = True - # Always start kvmanager when all TP clients are registered - try: - flexkv_logger.info("All TP clients registered, starting KVManager...") - self.kvmanager.start() - flexkv_logger.info("KVManager started successfully") - except 
Exception as e: - flexkv_logger.warning(f"KVManager start failed or already started: {e}") - flexkv_logger.info("All TP clients registered. SchedulerServer is ready!") - - response = Response(req.dp_client_id, success=True) - - except Exception as e: - flexkv_logger.error(f"Failed to register TP client {tp_rank}: {e}") - response = Response(req.dp_client_id, success=False, error_msg=str(e)) + def _handle_register_dp_client_request(self, req: RegisterDPClientRequest) -> None: + """Handle DP client registration request""" + self._verify_model_config(req.model_config) + client_id = self.client_manager.register_dp_client( + self.context, + req.client_recv_port, + req.model_config.tp_size + ) + flexkv_logger.info(f"DP client {client_id} registered successfully") - # Send response to TP client - if tp_rank in self.tp_client_dict: - self.tp_client_dict[tp_rank].send_to_client.send_pyobj(response) + def _handle_is_ready_request(self, req: IsReadyRequest) -> None: + """Handle ready state check request""" + is_ready = self.kv_task_engine.is_ready() + response = Response(req.dp_client_id, is_ready=is_ready) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_get_request(self, req: GetRequest) -> None: + """Handle Get request""" + req_id = self.kv_task_engine.get_async( + task_id=req.task_id, + token_ids=req.token_ids, + slot_mapping=req.slot_mapping, + token_mask=req.token_mask, + layer_granularity=req.layer_granularity, + dp_id=req.dp_client_id, + ) - def put_async( - self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor] = None, - ) -> Optional[int]: - """ - Asynchronous PUT operation, directly calling KVManager (no network communication required) - - Args: - token_ids: Token IDs tensor - slot_mapping: Slot mapping tensor - token_mask: Optional token mask tensor - - Returns: - Task ID if successful, None otherwise - """ - start_time = time.time() - - if not self.is_ready: - flexkv_logger.error("SchedulerServer is not ready (not all TP clients registered)") - return None - - try: - task_id = self._get_task_id() - req_id = self.kvmanager.put_async( - token_ids=token_ids, - slot_mapping=slot_mapping, - token_mask=token_mask, - dp_id=self.dp_client_id, - task_id=task_id, - ) - - end_time = time.time() - flexkv_logger.info(f"[SchedulerServer] put_async task: {task_id} created. " - f"time: {(end_time - start_time)*1000:.2f}ms") - return task_id - - except Exception as e: - flexkv_logger.error(f"put_async failed: {e}") - return None - - def get_async( - self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor] = None, - ) -> Optional[int]: - """ - Asynchronous GET operation, directly calling KVManager (no network communication required) - - Args: - token_ids: Token IDs tensor - slot_mapping: Slot mapping tensor - token_mask: Optional token mask tensor - - Returns: - Task ID if successful, None otherwise - """ - start_time = time.time() - - if not self.is_ready: - flexkv_logger.error("SchedulerServer is not ready (not all TP clients registered)") - return None - - try: - task_id = self._get_task_id() - req_id = self.kvmanager.get_async( - token_ids=token_ids, - slot_mapping=slot_mapping, - token_mask=token_mask, - layer_granularity=-1, - dp_id=self.dp_client_id, - task_id=task_id, - ) - - end_time = time.time() - flexkv_logger.info(f"[SchedulerServer] get_async task: {task_id} created. 
" - f"time: {(end_time - start_time)*1000:.2f}ms") - return task_id - - except Exception as e: - flexkv_logger.error(f"get_async failed: {e}") - return None - - def wait( - self, - wait_task_ids: List[int], - wait_timeout: float = 20.0, - ) -> Optional[Dict[int, torch.Tensor]]: - """ - Wait for specified tasks to complete, directly calling KVManager (no network communication required) - - Args: - wait_task_ids: List of task IDs to wait for - - Returns: - Dictionary mapping task IDs to result masks, None if failed - """ - try: - masks = self.kvmanager.wait(wait_task_ids, timeout=wait_timeout) - flexkv_logger.info(f"[SchedulerServer] wait tasks: {wait_task_ids} finished.") - return masks - - except Exception as e: - flexkv_logger.error(f"wait failed: {e}") - return None - - def try_wait( - self, - try_wait_task_ids: List[int], - ) -> Optional[Dict[int, torch.Tensor]]: - """ - Non-blocking wait for specified tasks, directly calling KVManager (no network communication required) - - Args: - try_wait_task_ids: List of task IDs to try waiting for - - Returns: - Dictionary mapping task IDs to result masks, None if not ready or failed - """ - try: - masks = self.kvmanager.try_wait(try_wait_task_ids) - if masks is not None: - flexkv_logger.info(f"[SchedulerServer] try_wait tasks: {try_wait_task_ids} finished.") - return masks - - except Exception as e: - flexkv_logger.error(f"try_wait failed: {e}") - return None - - def check_running(self) -> bool: - return self.kvmanager.is_running() + def _handle_put_request(self, req: PutRequest) -> None: + """Handle Put request""" + req_id = self.kv_task_engine.put_async( + token_ids=req.token_ids, + slot_mapping=req.slot_mapping, + token_mask=req.token_mask, + dp_id=req.dp_client_id, + task_id=req.task_id, + ) - def shutdown(self) -> None: - """Shutdown SchedulerServer""" - flexkv_logger.info("Shutting down SchedulerServer...") + def _handle_get_match_request(self, req: GetMatchRequest) -> None: + """Handle GetMatch request""" + req_id, mask = self.kv_task_engine.get_match( + token_ids=req.token_ids, + token_mask=req.token_mask, + layer_granularity=req.layer_granularity, + dp_id=req.dp_client_id, + task_id=req.task_id, + ) + response = Response(req.dp_client_id, task_id=req_id, mask=mask) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_put_match_request(self, req: PutMatchRequest) -> None: + """Handle PutMatch request""" + req_id, mask = self.kv_task_engine.put_match( + token_ids=req.token_ids, + token_mask=req.token_mask, + dp_id=req.dp_client_id, + task_id=req.task_id, + ) + response = Response(req.dp_client_id, task_id=req_id, mask=mask) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_launch_task_request(self, req: LaunchTaskRequest) -> None: + """Handle LaunchTask request""" + self.kv_task_engine.launch_tasks(req.task_ids, req.slot_mappings) + + def _handle_cancel_task_request(self, req: CancelTaskRequest) -> None: + """Handle CancelTask request""" + self.kv_task_engine.cancel_tasks(req.task_ids) + + def _handle_wait_request(self, req: WaitRequest) -> None: + """Handle Wait request""" + kv_responses = self.kv_task_engine.wait( + req.wait_task_ids, + timeout=req.wait_timeout, + completely=req.completely, + ) + response = Response(req.dp_client_id, status=kv_responses) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_try_wait_request(self, req: TryWaitRequest) -> None: + 
"""Handle TryWait request""" + kv_responses = self.kv_task_engine.try_wait( + req.try_wait_task_ids, + ) + response = Response(req.dp_client_id, status=kv_responses) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) - # Stop server thread + def _handle_shutdown_request(self, req: ShutdownRequest) -> None: + """Handle shutdown request""" + flexkv_logger.info(f"Received shutdown request from DP client {req.dp_client_id}") self._running = False - if self._server_thread is not None and self._server_thread.is_alive(): - self._server_thread.join(timeout=5.0) - - # Shutdown KVManager - if hasattr(self, 'kvmanager'): - self.kvmanager.shutdown() - - # Close ZMQ context - #if hasattr(self, 'context'): - # self.context.term() - - flexkv_logger.info("SchedulerServer shutdown complete") - - def get_server_port(self) -> str: - """Get server receive port for TPClient to use""" - return self.server_recv_port def __del__(self) -> None: - """Destructor""" - with contextlib.suppress(Exception): - self.shutdown() - + self.kv_task_engine.shutdown() if __name__ == "__main__": import torch @@ -694,7 +378,6 @@ def __del__(self) -> None: enable_ssd=False, enable_remote=False, use_gds=False, - use_pinned_memory=True, tokens_per_block=tokens_per_block, num_cpu_blocks=num_cpu_blocks,) diff --git a/flexkv/storage/allocator.py b/flexkv/storage/allocator.py index 7cd38156e0..ed683e6505 100644 --- a/flexkv/storage/allocator.py +++ b/flexkv/storage/allocator.py @@ -95,7 +95,6 @@ def allocate(cls, layout: KVCacheLayout, dtype: torch.dtype, **kwargs: Any) -> StorageHandle: - pin_memory = kwargs.get("pin_memory", True) total_size = layout.get_total_elements() # although the kv layout may have multiple dimensions, we only have one-dim CPU tensor flexkv_logger.info(f"CPU allocate total_size: {2 * total_size/1024/1024/1024} GB") @@ -103,7 +102,7 @@ def allocate(cls, size=(total_size,), dtype=dtype, device="cpu", - pin_memory=pin_memory, + pin_memory=False, ) return StorageHandle( handle_type=AccessHandleType.TENSOR, diff --git a/flexkv/storage/storage_engine.py b/flexkv/storage/storage_engine.py index 0762b0062d..0d48fe6230 100644 --- a/flexkv/storage/storage_engine.py +++ b/flexkv/storage/storage_engine.py @@ -35,7 +35,6 @@ def __init__(self, device_type=DeviceType.CPU, layout=self._cpu_layout, dtype=self._model_config.dtype, - pin_memory=self._cache_config.use_pinned_memory, ) if self._cache_config.enable_ssd: if not self._cache_config.ssd_kv_layout_type == self._cpu_layout.type: diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 30516e6c3f..1b4036cbe3 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -37,8 +37,19 @@ tpGPUCPUTransferWorker, ) from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.ring_buffer import SharedOpPool +def register_op_to_buffer(op: TransferOp, pin_buffer: SharedOpPool) -> None: + op.src_slot_id = pin_buffer.allocate_slot(op.src_block_ids) + op.dst_slot_id = pin_buffer.allocate_slot(op.dst_block_ids) + +def free_op_from_buffer(op: TransferOp, pin_buffer: SharedOpPool) -> None: + if op.src_slot_id != -1: + pin_buffer.free_slot(op.src_slot_id) + if op.dst_slot_id != -1: + pin_buffer.free_slot(op.dst_slot_id) + class TransferEngine: def __init__(self, gpu_handles: List[StorageHandle], @@ -71,6 +82,8 @@ def __init__(self, self._remote_handle = remote_handle self._cache_config = cache_config + self.pin_buffer = SharedOpPool(2048, 
self.model_config.max_req_tokens // self.cache_config.tokens_per_block) + self.op_id_to_nvtx_range: Dict[int, str] = {} self.dp_size = model_config.dp_size @@ -88,8 +101,8 @@ def _init_workers(self) -> None: if self.tp_size == 1: self.gpucpu_workers: List[WorkerHandle] = [ GPUCPUTransferWorker.create_worker( - worker_id=i, finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor = self.pin_buffer.get_buffer(), gpu_blocks=self.gpu_handles[i].get_tensor_handle_list(), cpu_blocks=self._cpu_handle.get_tensor(), gpu_kv_layout=self.gpu_handles[i].kv_layout, @@ -106,12 +119,12 @@ def _init_workers(self) -> None: else: self.gpucpu_workers = [ tpGPUCPUTransferWorker.create_worker( - worker_id=i, finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor = self.pin_buffer.get_buffer(), gpu_blocks=[self.gpu_handles[j].get_tensor_handle_list() \ for j in range(i * self.tp_size, (i + 1) * self.tp_size)], cpu_blocks=self._cpu_handle.get_tensor(), - gpu_kv_layout=self.gpu_handles[i].kv_layout, + gpu_kv_layouts=[self.gpu_handles[i].kv_layout for i in range(i * self.tp_size, (i + 1) * self.tp_size)], cpu_kv_layout=self._cpu_handle.kv_layout, dtype=self.gpu_handles[i].dtype, tp_group_size=self.tp_size, @@ -128,8 +141,8 @@ def _init_workers(self) -> None: if self._ssd_handle is not None and self._cpu_handle is not None: self.cpussd_read_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( - worker_id=10, finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -139,8 +152,8 @@ def _init_workers(self) -> None: cache_config=self._cache_config, ) self.cpussd_write_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( - worker_id=11, finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -153,8 +166,8 @@ def _init_workers(self) -> None: self._worker_map[TransferType.DISK2H] = self.cpussd_read_worker if self._remote_handle is not None and self._cpu_handle is not None: self.remotecpu_read_worker: WorkerHandle = CPURemoteTransferWorker.create_worker( - worker_id=20, finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), remote_file=self._remote_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -163,8 +176,8 @@ def _init_workers(self) -> None: remote_config_custom=self._remote_handle.remote_config_custom, ) self.remotecpu_write_worker: WorkerHandle = CPURemoteTransferWorker.create_worker( - worker_id=21, finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), remote_file=self._remote_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -177,18 +190,19 @@ def _init_workers(self) -> None: if len(self._worker_map) == 0: raise ValueError("No workers initialized, please check the config") # Wait for all workers to ready - for worker in self._worker_map.values(): + for transfer_type, worker in self._worker_map.items(): if isinstance(worker, List): for w in worker: - w.ready_event.wait(timeout=60) + flexkv_logger.info(f"waiting for {transfer_type.name} worker {w.worker_id} to ready") + w.ready_event.wait() + flexkv_logger.info(f"{transfer_type.name} worker 
{w.worker_id} is ready") else: - flexkv_logger.info(f"waiting for worker {worker} to ready") - worker.ready_event.wait(timeout=60) - flexkv_logger.info(f"worker {worker} is ready") + flexkv_logger.info(f"waiting for {transfer_type.name} worker {worker.worker_id} to ready") + worker.ready_event.wait() + flexkv_logger.info(f"{transfer_type.name} worker {worker.worker_id} is ready") # Start scheduler thread self._running = True self._scheduler_thread = threading.Thread(target=self._scheduler_loop) - flexkv_logger.info("TransferEngine initialized and running") self._scheduler_thread.start() def start(self) -> None: @@ -212,6 +226,7 @@ def _scheduler_loop(self) -> None: try: op_id = self.finished_ops_queue.get_nowait() op = self.op_id_to_op[op_id] + free_op_from_buffer(op, self.pin_buffer) self.completed_queue.put((op.graph_id, op.op_id)) finished_ops.append(op) del self.op_id_to_op[op_id] @@ -229,6 +244,8 @@ def _scheduler_loop(self) -> None: self.completed_queue.put((op.graph_id, op.op_id)) else: self.op_id_to_op[op.op_id] = op + # copy block ids into buffer and update slot id info + register_op_to_buffer(op, self.pin_buffer) self._assign_op_to_worker(op) # Handle completed graphs for graph_id in completed_graph_ids: @@ -286,6 +303,8 @@ def get_completed_graphs_and_ops(self, timeout: Optional[float] = None) -> List[ def shutdown(self) -> None: """Shutdown the transfer engine""" try: + if not self._running: + return self._running = False self._scheduler_thread.join(timeout=5) diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 4d72a0fb93..2dbbc1be86 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -1,10 +1,10 @@ import copy -import multiprocessing as mp +import torch.multiprocessing as mp import threading import time from abc import ABC, abstractmethod from dataclasses import dataclass -from multiprocessing import Queue as MPQueue, Pipe as MPPipe +from torch.multiprocessing import Queue as MPQueue, Pipe as MPPipe from multiprocessing.connection import Connection from threading import Thread from typing import List, Any, Dict, Union, Optional @@ -51,31 +51,55 @@ class WorkerTransferOp: transfer_op_id: int transfer_graph_id: int transfer_type: TransferType - src_block_ids: np.ndarray - dst_block_ids: np.ndarray layer_id: int layer_granularity: int + src_slot_id: int + dst_slot_id: int + valid_block_num: int + src_block_ids: np.ndarray + dst_block_ids: np.ndarray # successors: List[int] def __init__(self, transfer_op: TransferOp): self.transfer_op_id = transfer_op.op_id self.transfer_graph_id = transfer_op.graph_id self.transfer_type = transfer_op.transfer_type - self.src_block_ids = transfer_op.src_descriptor.physical_block_ids.numpy() - self.dst_block_ids = transfer_op.dst_descriptor.physical_block_ids.numpy() self.layer_id = transfer_op.layer_id self.layer_granularity = transfer_op.layer_granularity + self.src_slot_id = transfer_op.src_slot_id + self.dst_slot_id = transfer_op.dst_slot_id + self.valid_block_num = transfer_op.valid_block_num + if self.src_slot_id == -1: + self.src_block_ids = transfer_op.src_block_ids + self.dst_block_ids = transfer_op.dst_block_ids + else: + self.src_block_ids = np.empty(0) + self.dst_block_ids = np.empty(0) # self.successors = list(transfer_op.successors) # for nvtx class TransferWorkerBase(ABC): + _worker_id_counter = 0 + _worker_id_lock = threading.Lock() + def __init__(self, worker_id: int, transfer_conn: Connection, # receive end of pipe - finished_ops_queue: MPQueue): + finished_ops_queue: MPQueue, + 
op_buffer_tensor: torch.Tensor): self.worker_id = worker_id self.transfer_conn = transfer_conn # receive end of pipe self.finished_ops_queue: MPQueue[int] = finished_ops_queue + self.op_buffer_tensor = op_buffer_tensor + cudaHostRegister(self.op_buffer_tensor) + + @classmethod + def _get_worker_id(cls) -> int: + with cls._worker_id_lock: + worker_id = cls._worker_id_counter + cls._worker_id_counter += 1 + return worker_id + def _get_layer_ptrs(self, layer_blocks: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: if isinstance(layer_blocks, torch.Tensor): layer_blocks = [layer_blocks] @@ -90,14 +114,18 @@ def _get_layer_ptrs(self, layer_blocks: Union[List[torch.Tensor], torch.Tensor]) return layer_ptrs @classmethod - def create_worker(cls, worker_id: int, finished_ops_queue: MPQueue, *args: Any, **kwargs: Any) -> 'WorkerHandle': + def create_worker(cls, + finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, + *args: Any, **kwargs: Any) -> 'WorkerHandle': """Generic worker creation template method""" parent_conn, child_conn = MPPipe() # create pipe ready_event = mp.Event() + worker_id = cls._get_worker_id() process = mp.Process( target=cls._worker_process, - args=(worker_id, child_conn, finished_ops_queue, ready_event, *args), + args=(worker_id, child_conn, finished_ops_queue, op_buffer_tensor, ready_event, *args), kwargs=kwargs, daemon=True ) @@ -107,16 +135,16 @@ def create_worker(cls, worker_id: int, finished_ops_queue: MPQueue, *args: Any, @classmethod def _worker_process(cls, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, - ready_event: Any, *args: Any, **kwargs: Any) -> None: - worker = cls(worker_id, transfer_conn, finished_ops_queue, *args, **kwargs) + op_buffer_tensor: torch.Tensor, ready_event: Any, *args: Any, **kwargs: Any) -> None: + worker = cls(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor, *args, **kwargs) ready_event.set() worker.run() @abstractmethod def _transfer_impl( self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, @@ -124,6 +152,35 @@ def _transfer_impl( ) -> None: pass + def get_transfer_block_ids(self, + transfer_op: WorkerTransferOp, + pinned: bool = True) ->tuple[torch.Tensor, torch.Tensor]: + """ + Get transfer block ids from op buffer tensor or directly from op + Args: + transfer_op: WorkerTransferOp + pinned: whether to pin the block ids tensor + Returns: + tuple[torch.Tensor, torch.Tensor]: src_block_ids and dst_block_ids + """ + src_slot_id = transfer_op.src_slot_id + dst_slot_id = transfer_op.dst_slot_id + valid_block_num = transfer_op.valid_block_num + if src_slot_id == -1: + src_block_ids = torch.from_numpy(transfer_op.src_block_ids).to(dtype=torch.int64) + if pinned: + src_block_ids = src_block_ids.pin_memory() + else: + src_block_ids = self.op_buffer_tensor[src_slot_id, :valid_block_num] + if dst_slot_id == -1: + dst_block_ids = torch.from_numpy(transfer_op.dst_block_ids).to(dtype=torch.int64) + if pinned: + dst_block_ids = dst_block_ids.pin_memory() + else: + dst_block_ids = self.op_buffer_tensor[dst_slot_id, :valid_block_num] + + return src_block_ids, dst_block_ids + def _log_transfer_performance(self, transfer_op: WorkerTransferOp, transfer_size: int, @@ -159,7 +216,7 @@ def run(self) -> None: try: nvtx.push_range(f"launch {op.transfer_type.name} op_id: {op.transfer_op_id}, " f"graph_id: {op.transfer_graph_id}, " - f"num_blocks: 
{len(op.src_block_ids)}", + f"num_blocks: {op.valid_block_num}", color=get_nvtx_range_color(op.transfer_graph_id)) self.launch_transfer(op) nvtx.pop_range() @@ -180,7 +237,7 @@ class WorkerHandle: """handle for worker process""" def __init__(self, worker_id: int, transfer_conn: Connection, process: mp.Process, ready_event: Any): self.worker_id = worker_id - self.transfer_conn = transfer_conn # send end of pipe + self.transfer_conn = transfer_conn self.process = process self.ready_event = ready_event @@ -209,6 +266,7 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, gpu_blocks: List[TensorSharedHandle], cpu_blocks: torch.Tensor, gpu_kv_layout: KVCacheLayout, @@ -220,7 +278,7 @@ def __init__(self, transfer_sms_h2d: int = 8, transfer_sms_d2h: int = 8) -> None: # initialize worker in a new process - super().__init__(worker_id, transfer_conn, finished_ops_queue) + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) # Register CPU tensors with CUDA cudaHostRegister(cpu_blocks) self.gpu_blocks = [wrapper.get_tensor() for wrapper in gpu_blocks] @@ -258,33 +316,30 @@ def __init__(self, def _transfer_impl( self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, **kwargs: Any, ) -> None: - assert src_block_ids.dtype == np.int64 - assert dst_block_ids.dtype == np.int64 + assert src_block_ids.dtype == torch.int64 + assert dst_block_ids.dtype == torch.int64 assert len(src_block_ids) == len(dst_block_ids) if transfer_type == TransferType.H2D: - gpu_block_ids = dst_block_ids - cpu_block_ids = src_block_ids + gpu_block_id_list = dst_block_ids + cpu_block_id_list = src_block_ids use_ce_transfer = self.use_ce_transfer_h2d transfer_sms = self.transfer_sms_h2d elif transfer_type == TransferType.D2H: - gpu_block_ids = src_block_ids - cpu_block_ids = dst_block_ids + gpu_block_id_list = src_block_ids + cpu_block_id_list = dst_block_ids use_ce_transfer = self.use_ce_transfer_d2h transfer_sms = self.transfer_sms_d2h else: raise ValueError(f"Invalid transfer type: {transfer_type} for GPUCPUTransferWorker") - gpu_block_id_list = torch.from_numpy(gpu_block_ids).to(dtype=torch.int64).pin_memory() - cpu_block_id_list = torch.from_numpy(cpu_block_ids).to(dtype=torch.int64).pin_memory() - assert len(gpu_block_id_list) == len(cpu_block_id_list) if len(gpu_block_id_list) == 0: @@ -320,11 +375,13 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: if layer_granularity == -1: layer_granularity = self.num_layers + src_block_ids, dst_block_ids = self.get_transfer_block_ids(transfer_op) + with torch.cuda.stream(self.transfer_stream): start_time = time.time() self._transfer_impl( - transfer_op.src_block_ids, - transfer_op.dst_block_ids, + src_block_ids, + dst_block_ids, transfer_op.transfer_type, layer_id, layer_granularity, @@ -332,7 +389,7 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.chunk_size_in_bytes * layer_granularity * len(transfer_op.src_block_ids) * kv_dim + transfer_size = self.chunk_size_in_bytes * layer_granularity * transfer_op.valid_block_num * kv_dim self._log_transfer_performance( transfer_op, @@ -346,9 +403,10 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, gpu_blocks: 
List[List[TensorSharedHandle]], cpu_blocks: torch.Tensor, - gpu_kv_layout: KVCacheLayout, + gpu_kv_layouts: List[KVCacheLayout], cpu_kv_layout: KVCacheLayout, dtype: torch.dtype, tp_group_size: int, @@ -358,7 +416,7 @@ def __init__(self, transfer_sms_h2d: int = 8, transfer_sms_d2h: int = 8): - super().__init__(worker_id, transfer_conn, finished_ops_queue) + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) assert len(gpu_blocks) == tp_group_size # Handle tensor import for multi-process case imported_gpu_blocks = [] @@ -369,7 +427,7 @@ def __init__(self, imported_gpu_blocks.append(blocks_in_one_gpu) self.gpu_blocks = imported_gpu_blocks self.dtype = dtype - self.is_mla = gpu_kv_layout.is_mla + self.is_mla = gpu_kv_layouts[0].is_mla self.num_gpus = len(self.gpu_blocks) self.tp_group_size = tp_group_size @@ -377,19 +435,22 @@ def __init__(self, cudaHostRegister(cpu_blocks) - self.num_layers = gpu_kv_layout.num_layer - gpu_kv_layout_per_layer = gpu_kv_layout.div_layer(self.num_layers) + self.num_layers = gpu_kv_layouts[0].num_layer + gpu_kv_layouts_per_layer = [gpu_kv_layout.div_layer(self.num_layers) for gpu_kv_layout in gpu_kv_layouts] - self.gpu_chunk_size_in_bytes = gpu_kv_layout_per_layer.get_chunk_size() * self.dtype.itemsize - self.gpu_kv_stride_in_bytes = gpu_kv_layout_per_layer.get_kv_stride() * self.dtype.itemsize - self.gpu_block_stride_in_bytes = gpu_kv_layout_per_layer.get_block_stride() * self.dtype.itemsize + self.gpu_chunk_sizes_in_bytes = [gpu_kv_layout_per_layer.get_chunk_size() * self.dtype.itemsize \ + for gpu_kv_layout_per_layer in gpu_kv_layouts_per_layer] + self.gpu_kv_strides_in_bytes = [gpu_kv_layout_per_layer.get_kv_stride() * self.dtype.itemsize \ + for gpu_kv_layout_per_layer in gpu_kv_layouts_per_layer] + self.gpu_block_strides_in_bytes = [gpu_kv_layout_per_layer.get_block_stride() * self.dtype.itemsize \ + for gpu_kv_layout_per_layer in gpu_kv_layouts_per_layer] self.cpu_chunk_size_in_bytes = cpu_kv_layout.get_chunk_size() * self.dtype.itemsize self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize self.cpu_block_stride_in_bytes = cpu_kv_layout.get_block_stride() * self.dtype.itemsize - if not gpu_kv_layout.type == KVCacheLayoutType.LAYERWISE: + if not gpu_kv_layouts[0].type == KVCacheLayoutType.LAYERWISE: raise ValueError("Only layerwise layout is supported for GPU") self.transfer_sms_h2d = transfer_sms_h2d @@ -397,46 +458,47 @@ def __init__(self, self.use_ce_transfer_h2d = use_ce_transfer_h2d self.use_ce_transfer_d2h = use_ce_transfer_d2h - self.tp_transfer_thread_group = TPTransferThreadGroup(self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id) + gpu_kv_strides_tensor = torch.tensor(self.gpu_kv_strides_in_bytes, dtype=torch.int64) + gpu_block_strides_tensor = torch.tensor(self.gpu_block_strides_in_bytes, dtype=torch.int64) + gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) + + self.tp_transfer_thread_group = TPTransferThreadGroup(self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id, + gpu_kv_strides_tensor, gpu_block_strides_tensor, gpu_chunk_sizes_tensor) + def _transfer_impl(self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, **kwargs: Any, )->None: - assert src_block_ids.dtype == np.int64 - assert dst_block_ids.dtype == np.int64 + 
assert src_block_ids.dtype == torch.int64 + assert dst_block_ids.dtype == torch.int64 assert len(src_block_ids) == len(dst_block_ids) if transfer_type == TransferType.H2D: - gpu_block_ids = dst_block_ids - cpu_block_ids = src_block_ids + gpu_block_id_list = dst_block_ids + cpu_block_id_list = src_block_ids use_ce_transfer = self.use_ce_transfer_h2d transfer_sms = self.transfer_sms_h2d elif transfer_type == TransferType.D2H: - gpu_block_ids = src_block_ids - cpu_block_ids = dst_block_ids + gpu_block_id_list = src_block_ids + cpu_block_id_list = dst_block_ids use_ce_transfer = self.use_ce_transfer_d2h transfer_sms = self.transfer_sms_d2h else: raise ValueError(f"Invalid transfer type: {transfer_type} for tpGPUCPUTransferWorker") - gpu_block_id_list = torch.from_numpy(gpu_block_ids).to(dtype=torch.int64).pin_memory() - cpu_block_id_list = torch.from_numpy(cpu_block_ids).to(dtype=torch.int64).pin_memory() assert len(gpu_block_id_list) == len(cpu_block_id_list) if len(gpu_block_id_list) == 0: return - + self.tp_transfer_thread_group.tp_group_transfer( gpu_block_id_list, - self.gpu_kv_stride_in_bytes, - self.gpu_block_stride_in_bytes, - self.gpu_chunk_size_in_bytes, cpu_block_id_list, self.cpu_kv_stride_in_bytes, self.cpu_layer_stride_in_bytes, @@ -459,10 +521,12 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: if layer_granularity == -1: layer_granularity = self.num_layers + src_block_ids, dst_block_ids = self.get_transfer_block_ids(transfer_op) + start_time = time.time() self._transfer_impl( - transfer_op.src_block_ids, - transfer_op.dst_block_ids, + src_block_ids, + dst_block_ids, transfer_op.transfer_type, layer_id, layer_granularity, @@ -470,7 +534,7 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.cpu_chunk_size_in_bytes * layer_granularity * len(transfer_op.src_block_ids) * kv_dim + transfer_size = self.cpu_chunk_size_in_bytes * layer_granularity * transfer_op.valid_block_num * kv_dim self._log_transfer_performance( transfer_op, @@ -484,6 +548,7 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, cpu_blocks: torch.Tensor, ssd_files: Dict[int, List[str]], # ssd_device_id -> file_paths cpu_kv_layout: KVCacheLayout, @@ -491,7 +556,7 @@ def __init__(self, dtype: torch.dtype, num_blocks_per_file: int, cache_config: CacheConfig): - super().__init__(worker_id, transfer_conn, finished_ops_queue) + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) self.ssd_files = ssd_files self.num_blocks_per_file = num_blocks_per_file self.num_files = sum(len(file_list) for file_list in ssd_files.values()) @@ -528,30 +593,26 @@ def __init__(self, def _transfer_impl( self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, **kwargs: Any, ) -> None: - assert src_block_ids.dtype == np.int64 - assert dst_block_ids.dtype == np.int64 + assert src_block_ids.dtype == torch.int64 + assert dst_block_ids.dtype == torch.int64 assert len(src_block_ids) == len(dst_block_ids) if transfer_type == TransferType.H2DISK: - ssd_block_ids = dst_block_ids - cpu_block_ids = src_block_ids + ssd_block_id_list = dst_block_ids + cpu_block_id_list = src_block_ids elif transfer_type == TransferType.DISK2H: - ssd_block_ids = src_block_ids - cpu_block_ids = dst_block_ids + 
ssd_block_id_list = src_block_ids + cpu_block_id_list = dst_block_ids else: raise ValueError(f"Invalid transfer type: {transfer_type} for CPUSSDDiskTransferWorker") - # this means partial read hit cpu and other hit ssd - # or partial write hit ssd and none hit cpu - ssd_block_id_list = torch.from_numpy(ssd_block_ids).to(dtype=torch.int64) - cpu_block_id_list = torch.from_numpy(cpu_block_ids).to(dtype=torch.int64) layer_id_list = torch.arange(layer_id, layer_id + layer_granularity, dtype=torch.int32) @@ -581,10 +642,13 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: layer_id = 0 if layer_granularity == -1: layer_granularity = self.num_layers + + src_block_ids , dst_block_ids = self.get_transfer_block_ids(transfer_op) + start_time = time.time() self._transfer_impl( - transfer_op.src_block_ids, - transfer_op.dst_block_ids, + src_block_ids, + dst_block_ids, transfer_op.transfer_type, transfer_op.layer_id, transfer_op.layer_granularity, @@ -592,7 +656,7 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.chunk_size_in_bytes * layer_granularity * len(transfer_op.src_block_ids) * kv_dim + transfer_size = self.chunk_size_in_bytes * layer_granularity * transfer_op.valid_block_num * kv_dim self._log_transfer_performance( transfer_op, @@ -606,6 +670,7 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, cpu_blocks: List[torch.Tensor], remote_file: List[str], cpu_kv_layout: KVCacheLayout, @@ -614,7 +679,7 @@ def __init__(self, remote_config_custom: Dict[str, Any]): if transfer_kv_blocks_remote is None: raise RuntimeError("transfer_kv_blocks_remote not available, please build with FLEXKV_ENABLE_CFS=1") - super().__init__(worker_id, transfer_conn, finished_ops_queue) + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) self.cpu_layer_ptrs = self._get_layer_ptrs(cpu_blocks) self.remote_files = remote_file @@ -687,15 +752,15 @@ def __init__(self, def _transfer_impl( self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, **kwargs: Any ) -> None: - assert dst_block_ids.dtype == np.int64 - assert src_block_ids.dtype == np.int64 + assert src_block_ids.dtype == torch.int64 + assert dst_block_ids.dtype == torch.int64 assert len(src_block_ids) == len(dst_block_ids) if layer_id == -1: @@ -707,17 +772,14 @@ def _transfer_impl( # or partial write hit remote and none hit cpu if transfer_type == TransferType.H2REMOTE: - remote_block_ids = dst_block_ids - cpu_block_ids = src_block_ids + remote_block_id_list = dst_block_ids + cpu_block_id_list = src_block_ids elif transfer_type == TransferType.REMOTE2H: - remote_block_ids = src_block_ids - cpu_block_ids = dst_block_ids + remote_block_id_list = src_block_ids + cpu_block_id_list = dst_block_ids else: raise ValueError(f"Invalid transfer type: {transfer_type} for CPUSSDDiskTransferWorker") - remote_block_id_list = torch.from_numpy(remote_block_ids).pin_memory().to(dtype=torch.int64) - cpu_block_id_list = torch.from_numpy(cpu_block_ids).pin_memory().to(dtype=torch.int64) - layer_id_list = torch.arange(layer_id, layer_id + layer_granularity, dtype=torch.int32) transfer_kv_blocks_remote( file_nodeid_list=self.file_nodeid_list, @@ -748,10 +810,13 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: 
layer_id = 0 if layer_granularity == -1: layer_granularity = self.num_layers + + src_block_ids, dst_block_ids = self.get_transfer_block_ids(transfer_op) + start_time = time.time() self._transfer_impl( - transfer_op.src_block_ids, - transfer_op.dst_block_ids, + src_block_ids, + dst_block_ids, transfer_op.transfer_type, transfer_op.layer_id, transfer_op.layer_granularity, @@ -759,7 +824,7 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.chunk_size_in_bytes * layer_granularity * len(transfer_op.src_block_ids) * kv_dim + transfer_size = self.chunk_size_in_bytes * layer_granularity * transfer_op.valid_block_num * kv_dim self._log_transfer_performance( transfer_op, diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py new file mode 100644 index 0000000000..fab65f345c --- /dev/null +++ b/flexkv/transfer_manager.py @@ -0,0 +1,322 @@ +import multiprocessing as mp +import time +import queue +from queue import Queue +from typing import Dict, Optional, List, Tuple +from abc import ABC, abstractmethod +from multiprocessing import Process, Pipe, Event +import zmq +import tempfile +import threading +import numpy as np + +from flexkv.common.transfer import TransferOpGraph +from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.debug import flexkv_logger +from flexkv.common.memory_handle import TensorSharedHandle +from flexkv.common.transfer import DeviceType +from flexkv.common.storage import KVCacheLayout +from flexkv.storage.storage_engine import StorageEngine +from flexkv.transfer.transfer_engine import TransferEngine +from flexkv.server.utils import get_zmq_socket +from flexkv.server.request import RegisterTPClientRequest, Response + + +class TransferManager: + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str): + self.model_config = model_config + self.cache_config = cache_config + self.gpu_register_port = gpu_register_port + + self.all_gpu_layouts: Dict[int, KVCacheLayout] = {} + self.all_gpu_blocks: Dict[int, List[TensorSharedHandle]] = {} # device_id -> gpu_blocks + + self.context = zmq.Context(2) + self.recv_from_client = get_zmq_socket( + self.context, zmq.SocketType.PULL, gpu_register_port, True) + + self.transfer_engine: Optional[TransferEngine] = None + self.storage_engine = StorageEngine(self.model_config, self.cache_config) + + def _handle_gpu_blocks_registration(self, req: RegisterTPClientRequest) -> None: + device_id = req.device_id + + if device_id in self.all_gpu_blocks: + flexkv_logger.error(f"GPU {device_id} has already registered.") + elif device_id >= self.model_config.tp_size * self.model_config.dp_size: + flexkv_logger.error(f"GPU {device_id} is larger than TP size: " + f"{self.model_config.tp_size * self.model_config.dp_size}.") + else: + try: + self.all_gpu_blocks[device_id] = req.handles + self.all_gpu_layouts[device_id] = req.gpu_layout + flexkv_logger.info(f"GPU {device_id} registered successfully") + except Exception as e: + flexkv_logger.error(f"Failed to register GPU {device_id}: {e}") + + + def _register_gpu_blocks_via_socket(self) -> None: + try: + flexkv_logger.info(f"GPU tensor registration server started on port {self.gpu_register_port}") + + expected_gpus = self.model_config.tp_size * self.model_config.dp_size + + while len(self.all_gpu_blocks) < expected_gpus: + try: + req = self.recv_from_client.recv_pyobj(zmq.NOBLOCK) + except zmq.Again: + time.sleep(0.001) + continue + + if 
isinstance(req, RegisterTPClientRequest): + flexkv_logger.info(f"Received GPU blocks registration request: {type(req)}") + self._handle_gpu_blocks_registration(req) + else: + flexkv_logger.error(f"Unrecognized RequestType in SchedulerServer: {type(req)}") + + flexkv_logger.info(f"All {expected_gpus} GPUs registered successfully") + + except Exception as e: + flexkv_logger.error(f"Error in GPU registration server: {e}") + raise + finally: + pass + # TODO: fix the socket close issue + # self.recv_from_client.close() + # self.context.term() + + def initialize_transfer_engine(self) -> None: + self._register_gpu_blocks_via_socket() + + assert len(self.all_gpu_layouts) == self.model_config.tp_size * self.model_config.dp_size + assert len(self.all_gpu_blocks) == self.model_config.tp_size * self.model_config.dp_size + for device_id, gpu_blocks_wrapper in self.all_gpu_blocks.items(): + self.storage_engine.register_gpu_blocks(gpu_blocks_wrapper, + self.all_gpu_layouts[device_id], + device_id, + dtype=self.model_config.dtype) + self.gpu_handles = [ + self.storage_engine.get_storage_handle(DeviceType.GPU, i) + for i in range(self.model_config.tp_size * self.model_config.dp_size) + ] + cpu_handle = self.storage_engine.get_storage_handle(DeviceType.CPU) \ + if self.cache_config.enable_cpu else None + ssd_handle = self.storage_engine.get_storage_handle(DeviceType.SSD) \ + if self.cache_config.enable_ssd else None + remote_handle = ( + self.storage_engine.get_storage_handle(DeviceType.REMOTE) \ + if self.cache_config.enable_remote \ + else None + ) + self.transfer_engine = TransferEngine(gpu_handles=self.gpu_handles, + model_config=self.model_config, + cache_config=self.cache_config, + cpu_handle=cpu_handle, + ssd_handle=ssd_handle, + remote_handle=remote_handle) + + def submit(self, transfer_graph: TransferOpGraph) -> None: + self.transfer_engine.submit_transfer_graph(transfer_graph) + + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + return self.transfer_engine.get_completed_graphs_and_ops(timeout) + + def start(self) -> None: + self.transfer_engine.start() + + def shutdown(self) -> None: + self.transfer_engine.shutdown() + + +class TransferManagerHandleBase(ABC): + @abstractmethod + def start(self) -> None: + pass + + @abstractmethod + def is_ready(self) -> bool: + pass + + @abstractmethod + def submit(self, transfer_graph: TransferOpGraph) -> None: + pass + + @abstractmethod + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def shutdown(self) -> None: + pass + + +class TransferManagerIntraProcessHandle(TransferManagerHandleBase): + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str): + self.transfer_manager = TransferManager(model_config, cache_config, gpu_register_port) + self._is_ready = False + + def start(self) -> None: + self.transfer_manager.initialize_transfer_engine() + self.transfer_manager.start() + self._is_ready = True + + def is_ready(self) -> bool: + return self._is_ready + + def submit(self, transfer_graph: TransferOpGraph) -> None: + self.transfer_manager.submit(transfer_graph) + + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + return self.transfer_manager.wait(timeout) + + def shutdown(self) -> None: + self.transfer_manager.shutdown() + + +class TransferManagerInterProcessHandle(TransferManagerHandleBase): + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str): + 
mp.set_start_method('spawn', force=True) + + self.model_config = model_config + self.cache_config = cache_config + self.gpu_register_port = gpu_register_port + + self.command_parent_conn, self.command_child_conn = Pipe() + self.result_parent_conn, self.result_child_conn = Pipe() + + self.process: Optional[Process] = None + self.ready_event = Event() + + self._completed_results: List[Tuple[int, int]] = [] + + def _start_process(self) -> None: + if self.process is not None and self.process.is_alive(): + return + + self.process = Process( + target=self._process_worker, + args=(self.model_config, + self.cache_config, + self.command_child_conn, + self.result_child_conn, + self.gpu_register_port, + self.ready_event), + daemon=False + ) + self.process.start() + + def _process_worker(self, + model_config: ModelConfig, + cache_config: CacheConfig, + command_conn, + result_conn, + gpu_register_port: str, + ready_event) -> None: + try: + transfer_manager = TransferManager(model_config, cache_config, gpu_register_port) + transfer_manager.initialize_transfer_engine() + transfer_manager.start() + ready_event.set() + while True: + try: + if command_conn.poll(timeout=0.0001): + request = command_conn.recv() + request_type = request.get('type') + if request_type == 'submit': + transfer_manager.submit(request['transfer_graph']) + else: + flexkv_logger.error(f"Unrecognized request type: {request_type}") + try: + finished_ops = transfer_manager.wait(0.0001) + if finished_ops: + result_conn.send(finished_ops) + except queue.Empty: + pass + except Exception as e: + flexkv_logger.error(f"Error in transfer manager process: {e}") + + except Exception as e: + flexkv_logger.error(f"Failed to initialize transfer manager process: {e}") + finally: + command_conn.close() + result_conn.close() + + def start(self) -> None: + self._start_process() + + def is_ready(self) -> bool: + return self.ready_event.is_set() + + def submit(self, transfer_graph: TransferOpGraph) -> None: + self.command_parent_conn.send({ + 'type': 'submit', + 'transfer_graph': transfer_graph + }) + + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + finished_ops: List[Tuple[int, int]] = [] + try: + if self.result_parent_conn.poll(timeout=timeout): + finished_ops += self.result_parent_conn.recv() + while self.result_parent_conn.poll(): + finished_ops += self.result_parent_conn.recv() + except EOFError: + pass + + return finished_ops + + def shutdown(self) -> None: + if self.process is not None: + self.process.terminate() + self.process.join(timeout=5.0) + if self.process.is_alive(): + self.process.kill() + self.process.join() + + self.command_parent_conn.close() + self.result_parent_conn.close() + + def __del__(self): + self.shutdown() + + +class TransferManagerHandle: + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + use_separate_process: bool = True, + gpu_register_port: Optional[str] = None): + if gpu_register_port is None: + gpu_register_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" + if use_separate_process: + self._handle: TransferManagerHandleBase = TransferManagerInterProcessHandle( + model_config, cache_config, gpu_register_port + ) + else: + self._handle: TransferManagerHandleBase = TransferManagerIntraProcessHandle( + model_config, cache_config, gpu_register_port + ) + + def start(self) -> None: + self._handle.start() + + def is_ready(self) -> bool: + return self._handle.is_ready() + + def submit(self, transfer_graph: TransferOpGraph) -> None: + 
self._handle.submit(transfer_graph) + + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + return self._handle.wait(timeout) + + def shutdown(self) -> None: + self._handle.shutdown() diff --git a/setup.py b/setup.py index f69e28b923..fcd0f97a34 100755 --- a/setup.py +++ b/setup.py @@ -2,13 +2,16 @@ import shutil import sys -from Cython.Build import cythonize + from setuptools import find_packages, setup from setuptools.command.build_ext import build_ext from torch.utils import cpp_extension +def get_version(): + with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f: + return f.read().strip() -build_dir = os.path.abspath("build") +build_dir = "build" os.makedirs(build_dir, exist_ok=True) # Check if we're in debug mode using environment variable @@ -25,17 +28,19 @@ "csrc/hash.cpp", "csrc/tp_transfer_thread_group.cpp", "csrc/transfer_ssd.cpp", + "csrc/radix_tree.cpp", ] hpp_sources = [ "csrc/cache_utils.h", "csrc/tp_transfer_thread_group.h", "csrc/transfer_ssd.h", + "csrc/radix_tree.h", ] extra_link_args = ["-lcuda", "-lxxhash", "-lpthread", "-lrt", "-luring"] extra_compile_args = ["-std=c++17"] -include_dirs = [os.path.join(build_dir, "include")] +include_dirs = [os.path.abspath(os.path.join(build_dir, "include"))] # Add rpath to find libraries at runtime lib_dir = os.path.join(build_dir, "lib") @@ -76,6 +81,8 @@ "flexkv/**/benchmark_*.py", "flexkv/benchmark/**/*.py", "flexkv/benchmark/test_kvmanager.py"] + # Import cython when debug is turned off. + from Cython.Build import cythonize cythonized_modules = cythonize( python_files, exclude=excluded_files, @@ -84,6 +91,7 @@ "boundscheck": False, "wraparound": False, "initializedcheck": False, + "profile": True, }, build_dir=build_dir, # Direct Cython to use the build directory ) @@ -125,16 +133,17 @@ def copy_shared_libraries(self): setup( name="flexkv", description="A global KV-Cache manager for LLM inference", - version="0.1.0", + version=get_version(), packages=find_packages(exclude=("benchmarks", "csrc", "examples", "tests")), package_data={ - "flexkv": ["lib/*.so", "lib/*.so.*"], + "flexkv": ["*.so", "lib/*.so", "lib/*.so.*"], }, include_package_data=True, install_requires=install_requires, ext_modules=ext_modules, # Now contains both C++ and Cython modules as needed cmdclass={ "build_ext": CustomBuildExt.with_options( + include_dirs=os.path.join(build_dir, "include"), # Include directory for xxhash no_python_abi_suffix=True, build_temp=os.path.join(build_dir, "temp"), # Temporary build files ) diff --git a/tests/conftest.py b/tests/conftest.py index e04cf78e50..2aba477a6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,9 +4,3 @@ """ # Import fixtures from test_utils so pytest can discover them from test_utils import model_config, cache_config, test_config - -import multiprocessing as mp - -# Set the start method for multiprocessing to 'spawn' -# This ensures consistent behavior across different platforms -mp.set_start_method("spawn", force=True) diff --git a/tests/replay_from_tracer.py b/tests/replay_from_tracer.py index c9de75b720..fad6a20ea9 100644 --- a/tests/replay_from_tracer.py +++ b/tests/replay_from_tracer.py @@ -24,7 +24,7 @@ from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.common.memory_handle import TensorSharedHandle -from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVTaskEngine class FlexKVReplayEngine: @@ -113,7 +113,6 @@ def parse_config_event(self, event: Dict[str, 
Any]): ssd_kv_layout_type=self._parse_layout_type(cache_config_data['ssd_kv_layout_type']), remote_kv_layout_type=self._parse_layout_type(cache_config_data['remote_kv_layout_type']), use_gds=cache_config_data['use_gds'], - use_pinned_memory=False,#cache_config_data['use_pinned_memory'], # for local test remote_cache_size_mode=cache_config_data['remote_cache_size_mode'], num_cpu_blocks=cache_config_data['num_cpu_blocks'], num_ssd_blocks=cache_config_data['num_ssd_blocks'], @@ -204,7 +203,7 @@ def create_kvmanager(self,): ) # Create KVManager - self.kvmanager = KVManager( + self.kvmanager = KVTaskEngine( model_config=self.model_config, cache_config=self.cache_config, gpu_layout=self.gpu_layout, @@ -274,10 +273,6 @@ def replay_wait_event(self, event: Dict[str, Any]): result = self.kvmanager.wait_for_graph_finished(task_ids) elif wait_type == "try_wait": result = self.kvmanager.try_wait(task_ids) - elif wait_type == "wait_at_layer_group": - result = self.kvmanager.wait_at_layer_group(task_ids[0], layer_group_id) - elif wait_type == "try_wait_at_layer_group": - result = self.kvmanager.try_wait_at_layer_group(task_ids, layer_group_id) else: raise ValueError(f"Unknown wait type: {wait_type}") successed_elements = [] diff --git a/tests/test_cache_engine.py b/tests/test_cache_engine.py index 133dfb470c..e224d4305b 100644 --- a/tests/test_cache_engine.py +++ b/tests/test_cache_engine.py @@ -1,7 +1,7 @@ import random import pytest -import torch +import numpy as np from flexkv.cache.mempool import Mempool from flexkv.cache.cache_engine import CacheEngine @@ -16,6 +16,7 @@ def cache_engine(request: pytest.FixtureRequest) -> CacheEngine: 'device_type': DeviceType.CPU, 'num_total_blocks': 64, 'tokens_per_block': 4, + 'evict_ratio': 0.05, } default_config_kwargs.update(param) return CacheEngine(**default_config_kwargs) @@ -23,11 +24,11 @@ def cache_engine(request: pytest.FixtureRequest) -> CacheEngine: @pytest.mark.parametrize( "config, should_raise", [ - ({'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': DeviceType.CPU}, False), - ({'num_total_blocks': 0, 'tokens_per_block': 4, 'device_type': DeviceType.GPU}, True), - ({'num_total_blocks': 64, 'tokens_per_block': 0, 'device_type': DeviceType.SSD}, True), - ({'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': 'Unknown'}, True), - ({'num_total_blocks': 64, 'tokens_per_block': 3, 'device_type': DeviceType.CPU}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': DeviceType.CPU}, False), + ({'evict_ratio': 0.05, 'num_total_blocks': 0, 'tokens_per_block': 4, 'device_type': DeviceType.GPU}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 0, 'device_type': DeviceType.SSD}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': 'Unknown'}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 3, 'device_type': DeviceType.CPU}, True), ] ) def test_config_init(config: dict, should_raise: bool): @@ -42,14 +43,14 @@ def test_mempool(): mempool = Mempool(num_total_blocks=64) assert mempool.num_free_blocks == 64 block_ids = mempool.allocate_blocks(16) - assert isinstance(block_ids, torch.Tensor) - assert block_ids.dtype == torch.int64 + assert isinstance(block_ids, np.ndarray) + assert block_ids.dtype == np.int64 assert block_ids.shape == (16,) assert mempool.num_free_blocks == 48 mempool.recycle_blocks(block_ids) assert mempool.num_free_blocks == 64 - block_ids = torch.cat([mempool.allocate_blocks(16), + block_ids 
= np.concatenate([mempool.allocate_blocks(16), mempool.allocate_blocks(16), mempool.allocate_blocks(16), mempool.allocate_blocks(16)]) @@ -63,21 +64,21 @@ def test_mempool(): empty_blocks = mempool.allocate_blocks(0) assert empty_blocks.shape == (0, ) - assert empty_blocks.dtype == torch.int64 + assert empty_blocks.dtype == np.int64 assert mempool.num_free_blocks == 64 with pytest.raises(ValueError): mempool.allocate_blocks(-1) - mempool.recycle_blocks(torch.tensor([], dtype=torch.int64)) + mempool.recycle_blocks(np.array([], dtype=np.int64)) assert mempool.num_free_blocks == 64 with pytest.raises(ValueError): - mempool.recycle_blocks(torch.tensor([1, 2, 3], dtype=torch.int32)) + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int32)) with pytest.raises(ValueError): - mempool.recycle_blocks(torch.tensor([1, 2, 3], dtype=torch.int64)) + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int64)) with pytest.raises(ValueError): - mempool.recycle_blocks(torch.tensor([[1, 2, 3]], dtype=torch.int64)) + mempool.recycle_blocks(np.array([[1, 2, 3]], dtype=np.int64)) def test_reset(cache_engine: CacheEngine): cache_engine.reset() @@ -101,22 +102,22 @@ def test_reset(cache_engine: CacheEngine): [1, 10, 16, 32, 10000], ) def test_match_and_insert(cache_engine: CacheEngine, num_insert: int, seq_len: int): - base_token_ids = torch.randint(0, 10000, (seq_len, ), dtype=torch.int64) + base_token_ids = np.random.randint(0, 10000, (seq_len, ), dtype=np.int64) base_num_blocks = seq_len // cache_engine.tokens_per_block cache_engine.insert(SequenceMeta(token_ids=base_token_ids, tokens_per_block=cache_engine.tokens_per_block), - torch.arange(base_num_blocks, dtype=torch.int64), + np.arange(base_num_blocks, dtype=np.int64), is_ready=True) cur_cached_blocks = base_num_blocks for i in range(num_insert): prefix_ratio = random.random() prefix_len = int(len(base_token_ids)*prefix_ratio) num_prefix_blocks = prefix_len // cache_engine.tokens_per_block - token_ids = torch.cat([base_token_ids[:prefix_len], - torch.randint(10000 + i * seq_len, + token_ids = np.concatenate([base_token_ids[:prefix_len], + np.random.randint(10000 + i * seq_len, 10000 + (i+1) * seq_len, (seq_len-prefix_len, ), - dtype=torch.int64)]) + dtype=np.int64)]) insert_sequence_meta = SequenceMeta(token_ids=token_ids, tokens_per_block=cache_engine.tokens_per_block) match_result = cache_engine.match(insert_sequence_meta) @@ -125,11 +126,11 @@ def test_match_and_insert(cache_engine: CacheEngine, num_insert: int, seq_len: i assert match_result.last_ready_node is not None assert match_result.last_node is not None assert match_result.physical_blocks.shape == (num_prefix_blocks, ) - assert match_result.physical_blocks.dtype == torch.int64 + assert match_result.physical_blocks.dtype == np.int64 num_insert_blocks = insert_sequence_meta.num_blocks - num_prefix_blocks cache_engine.insert(insert_sequence_meta, - torch.arange(num_insert_blocks, dtype=torch.int64), + np.arange(num_insert_blocks, dtype=np.int64), is_ready=True, match_result=match_result) cur_cached_blocks += num_insert_blocks @@ -150,7 +151,7 @@ def test_take_and_recycle(cache_engine: CacheEngine): num_total_blocks = cache_engine.num_total_blocks tokens_per_block = cache_engine.tokens_per_block seq_blocks = 10 - token_ids = torch.randint(0, 10000, (seq_blocks * tokens_per_block, ), dtype=torch.int64) + token_ids = np.random.randint(0, 10000, (seq_blocks * tokens_per_block, ), dtype=np.int64) sequence_meta = SequenceMeta(token_ids=token_ids, tokens_per_block=tokens_per_block) physical_blocks = 
cache_engine.take(seq_blocks) @@ -159,7 +160,7 @@ def test_take_and_recycle(cache_engine: CacheEngine): empty_blocks = cache_engine.take(0) assert empty_blocks.shape == (0, ) - assert empty_blocks.dtype == torch.int64 + assert empty_blocks.dtype == np.int64 with pytest.raises(ValueError): cache_engine.take(-1) @@ -168,7 +169,7 @@ def test_take_and_recycle(cache_engine: CacheEngine): physical_blocks2 = cache_engine.take(num_total_blocks, protected_node=radixnode, strict=False) assert physical_blocks2.shape == (num_total_blocks - seq_blocks, ) - assert physical_blocks2.dtype == torch.int64 + assert physical_blocks2.dtype == np.int64 cache_engine.recycle(physical_blocks2) @@ -193,22 +194,22 @@ def test_cleanup(cache_engine: CacheEngine): if cache_engine.tokens_per_block != 1: pytest.skip("tokens_per_block != 1") tokens_per_block = cache_engine.tokens_per_block - token_ids_list = [torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64), - torch.tensor([0, 1, 2, 3, 17, 15, 19, 20], dtype=torch.int64), - torch.tensor([0, 23, 22, 21], dtype=torch.int64)] + token_ids_list = [np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64), + np.array([0, 1, 2, 3, 17, 15, 19, 20], dtype=np.int64), + np.array([0, 23, 22, 21], dtype=np.int64)] sequence_meta_list = [SequenceMeta(token_ids=token_ids, tokens_per_block=tokens_per_block) for token_ids in token_ids_list] num_insert_blocks0 = sequence_meta_list[0].num_blocks radixnode0 = cache_engine.insert(sequence_meta_list[0], - torch.arange(num_insert_blocks0, dtype=torch.int64), + np.arange(num_insert_blocks0, dtype=np.int64), is_ready=False) cache_engine.lock_node(radixnode0) radixnode0_size = radixnode0.size() match_result = cache_engine.match(sequence_meta_list[1]) num_insert_blocks1 = sequence_meta_list[1].num_blocks - match_result.num_matched_blocks radixnode1 = cache_engine.insert(sequence_meta_list[1], - torch.arange(num_insert_blocks1, dtype=torch.int64), + np.arange(num_insert_blocks1, dtype=np.int64), match_result=match_result, is_ready=False) cache_engine.lock_node(radixnode1) @@ -216,7 +217,7 @@ def test_cleanup(cache_engine: CacheEngine): match_result = cache_engine.match(sequence_meta_list[2]) num_insert_blocks2 = sequence_meta_list[2].num_blocks - match_result.num_matched_blocks radixnode2 = cache_engine.insert(sequence_meta_list[2], - torch.arange(num_insert_blocks2, dtype=torch.int64), + np.arange(num_insert_blocks2, dtype=np.int64), match_result=match_result, is_ready=False) cache_engine.lock_node(radixnode2) diff --git a/tests/test_cache_engine_accel.py b/tests/test_cache_engine_accel.py new file mode 100644 index 0000000000..15fef43ec3 --- /dev/null +++ b/tests/test_cache_engine_accel.py @@ -0,0 +1,232 @@ +import random + +import pytest +import numpy as np + +from flexkv.cache.mempool import Mempool +from flexkv.cache.cache_engine import CacheEngineAccel +from flexkv.common.transfer import DeviceType +from flexkv.common.exceptions import InvalidConfigError, NotEnoughSpaceError +from flexkv.common.block import SequenceMeta + +@pytest.fixture +def cache_engine(request: pytest.FixtureRequest) -> CacheEngineAccel: + param = request.param if hasattr(request, 'param') else {} + default_config_kwargs = { + 'device_type': DeviceType.CPU, + 'num_total_blocks': 64, + 'tokens_per_block': 4, + 'evict_ratio': 0.05, + } + default_config_kwargs.update(param) + return CacheEngineAccel(**default_config_kwargs) + +@pytest.mark.parametrize( + "config, should_raise", + [ + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 4, 
'device_type': DeviceType.CPU}, False), + ({'evict_ratio': 0.05, 'num_total_blocks': 0, 'tokens_per_block': 4, 'device_type': DeviceType.GPU}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 0, 'device_type': DeviceType.SSD}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': 'Unknown'}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 3, 'device_type': DeviceType.CPU}, True), + ] +) +def test_config_init(config: dict, should_raise: bool): + if should_raise: + with pytest.raises(InvalidConfigError) as e: + CacheEngineAccel(**config) + else: + engine = CacheEngineAccel(**config) + assert isinstance(engine, CacheEngineAccel) + +def test_mempool(): + mempool = Mempool(num_total_blocks=64) + assert mempool.num_free_blocks == 64 + block_ids = mempool.allocate_blocks(16) + assert isinstance(block_ids, np.ndarray) + assert block_ids.dtype == np.int64 + assert block_ids.shape == (16,) + assert mempool.num_free_blocks == 48 + mempool.recycle_blocks(block_ids) + assert mempool.num_free_blocks == 64 + + block_ids = np.concatenate([mempool.allocate_blocks(16), + mempool.allocate_blocks(16), + mempool.allocate_blocks(16), + mempool.allocate_blocks(16)]) + assert mempool.num_free_blocks == 0 + + with pytest.raises(NotEnoughSpaceError): + mempool.allocate_blocks(1) + + mempool.recycle_blocks(block_ids) + assert mempool.num_free_blocks == 64 + + empty_blocks = mempool.allocate_blocks(0) + assert empty_blocks.shape == (0, ) + assert empty_blocks.dtype == np.int64 + assert mempool.num_free_blocks == 64 + + with pytest.raises(ValueError): + mempool.allocate_blocks(-1) + + mempool.recycle_blocks(np.array([], dtype=np.int64)) + assert mempool.num_free_blocks == 64 + + with pytest.raises(ValueError): + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int32)) + with pytest.raises(ValueError): + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int64)) + with pytest.raises(ValueError): + mempool.recycle_blocks(np.array([[1, 2, 3]], dtype=np.int64)) + +def test_reset(cache_engine: CacheEngineAccel): + cache_engine.reset() + assert cache_engine.index.is_empty() + assert cache_engine.mempool.num_used_blocks == 0 + +@pytest.mark.parametrize( + "cache_engine", + [ + {'num_total_blocks': 10000000, 'tokens_per_block': 1, 'device_type': DeviceType.CPU}, + {'num_total_blocks': 10000000, 'tokens_per_block': 16, 'device_type': DeviceType.CPU}, + ], + indirect=True +) +@pytest.mark.parametrize( + "num_insert", + [100], +) +@pytest.mark.parametrize( + "seq_len", + [1, 10, 16, 32, 10000], +) +def test_match_and_insert(cache_engine: CacheEngineAccel, num_insert: int, seq_len: int): + base_token_ids = np.random.randint(0, 10000, (seq_len, ), dtype=np.int64) + base_num_blocks = seq_len // cache_engine.tokens_per_block + cache_engine.insert(SequenceMeta(token_ids=base_token_ids, + tokens_per_block=cache_engine.tokens_per_block), + np.arange(base_num_blocks, dtype=np.int64), + is_ready=True) + cur_cached_blocks = base_num_blocks + for i in range(num_insert): + prefix_ratio = random.random() + prefix_len = int(len(base_token_ids)*prefix_ratio) + num_prefix_blocks = prefix_len // cache_engine.tokens_per_block + token_ids = np.concatenate([base_token_ids[:prefix_len], + np.random.randint(10000 + i * seq_len, + 10000 + (i+1) * seq_len, + (seq_len-prefix_len, ), + dtype=np.int64)]) + insert_sequence_meta = SequenceMeta(token_ids=token_ids, + tokens_per_block=cache_engine.tokens_per_block) + match_result = 
cache_engine.match(insert_sequence_meta) + assert match_result.num_ready_matched_blocks == num_prefix_blocks + assert match_result.num_matched_blocks == num_prefix_blocks + + num_insert_blocks = insert_sequence_meta.num_blocks - num_prefix_blocks + cache_engine.insert(insert_sequence_meta, + np.arange(num_insert_blocks, dtype=np.int64), + is_ready=True, + match_result=match_result) + cur_cached_blocks += num_insert_blocks + assert cache_engine.index.total_cached_blocks() == cur_cached_blocks + + match_result = cache_engine.match(insert_sequence_meta) + assert match_result.num_matched_blocks == insert_sequence_meta.num_blocks + assert match_result.num_ready_matched_blocks == insert_sequence_meta.num_blocks + +@pytest.mark.parametrize( + "cache_engine", + [ + {'num_total_blocks': 100, 'tokens_per_block': 16, 'device_type': DeviceType.CPU}, + ], + indirect=True +) +def test_take_and_recycle(cache_engine: CacheEngineAccel): + num_total_blocks = cache_engine.num_total_blocks + tokens_per_block = cache_engine.tokens_per_block + seq_blocks = 10 + token_ids = np.random.randint(0, 10000, (seq_blocks * tokens_per_block, ), dtype=np.int64) + sequence_meta = SequenceMeta(token_ids=token_ids, + tokens_per_block=tokens_per_block) + physical_blocks = cache_engine.take(seq_blocks) + radixnode = cache_engine.insert(sequence_meta, physical_blocks, is_ready=True) + assert cache_engine.index.total_cached_blocks() == seq_blocks + + empty_blocks = cache_engine.take(0) + assert empty_blocks.shape == (0, ) + assert empty_blocks.dtype == np.int64 + + with pytest.raises(ValueError): + cache_engine.take(-1) + with pytest.raises(NotEnoughSpaceError): + cache_engine.take(num_total_blocks, protected_node=radixnode, strict=True) + + physical_blocks2 = cache_engine.take(num_total_blocks, protected_node=radixnode, strict=False) + assert physical_blocks2.shape == (num_total_blocks - seq_blocks, ) + assert physical_blocks2.dtype == np.int64 + + cache_engine.recycle(physical_blocks2) + + cache_engine.lock_node(radixnode) + with pytest.raises(NotEnoughSpaceError): + cache_engine.take(num_total_blocks, protected_node=radixnode, strict=True) + cache_engine.cleanup(radixnode, radixnode.size()) + + physical_blocks = cache_engine.take(num_total_blocks, protected_node=None, strict=True) + assert physical_blocks.shape == (num_total_blocks, ) + assert cache_engine.index.total_cached_blocks() == 0 + +@pytest.mark.parametrize( + "cache_engine", + [ + {'num_total_blocks': 100, 'tokens_per_block': 1, 'device_type': DeviceType.CPU}, + ], + indirect=True +) +def test_cleanup(cache_engine: CacheEngineAccel): + if cache_engine.tokens_per_block != 1: + pytest.skip("tokens_per_block != 1") + tokens_per_block = cache_engine.tokens_per_block + token_ids_list = [np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64), + np.array([0, 1, 2, 3, 17, 15, 19, 20], dtype=np.int64), + np.array([0, 23, 22, 21], dtype=np.int64)] + sequence_meta_list = [SequenceMeta(token_ids=token_ids, + tokens_per_block=tokens_per_block) + for token_ids in token_ids_list] + num_insert_blocks0 = sequence_meta_list[0].num_blocks + radixnode0 = cache_engine.insert(sequence_meta_list[0], + np.arange(num_insert_blocks0, dtype=np.int64), + is_ready=False) + cache_engine.lock_node(radixnode0) + radixnode0_size = radixnode0.size() + match_result = cache_engine.match(sequence_meta_list[1]) + num_insert_blocks1 = sequence_meta_list[1].num_blocks - match_result.num_matched_blocks + radixnode1 = cache_engine.insert(sequence_meta_list[1], + np.arange(num_insert_blocks1, 
dtype=np.int64), + match_result=match_result, + is_ready=False) + cache_engine.lock_node(radixnode1) + radixnode1_size = radixnode1.size() + match_result = cache_engine.match(sequence_meta_list[2]) + num_insert_blocks2 = sequence_meta_list[2].num_blocks - match_result.num_matched_blocks + radixnode2 = cache_engine.insert(sequence_meta_list[2], + np.arange(num_insert_blocks2, dtype=np.int64), + match_result=match_result, + is_ready=False) + cache_engine.lock_node(radixnode2) + radixnode2_size = radixnode2.size() + total_insert_blocks = num_insert_blocks0 + num_insert_blocks1 + num_insert_blocks2 + assert cache_engine.index.total_cached_blocks() == total_insert_blocks + assert cache_engine.index.total_unready_blocks() == total_insert_blocks + assert cache_engine.index.total_ready_blocks() == 0 + + cache_engine.cleanup(radixnode2, radixnode2_size) + assert cache_engine.index.total_ready_blocks() == num_insert_blocks2 + + cache_engine.cleanup(radixnode1, radixnode1_size) + assert cache_engine.index.total_ready_blocks() == num_insert_blocks1 + num_insert_blocks2 + + cache_engine.cleanup(radixnode0, radixnode0_size) + assert cache_engine.index.total_ready_blocks() == num_insert_blocks0 + num_insert_blocks1 + num_insert_blocks2 diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index c01bee7aaa..f8e74d45e2 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -4,19 +4,67 @@ import pytest import torch +import multiprocessing as mp +from multiprocessing import Process, Pipe from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.request import KVResponseStatus +from flexkv.kvtask import KVTaskEngine from flexkv.kvmanager import KVManager +from flexkv.common.memory_handle import TensorSharedHandle +from flexkv.server.client import KVTPClient +from flexkv.common.debug import flexkv_logger # Import utilities from test_utils from test_utils import ( DEFAULT_MODEL_CONFIG, DEFAULT_CACHE_CONFIG, DEFAULT_TEST_CONFIG, generate_request_pair, verify_data, block_ids_2_slot_mapping, generate_gpu_blocks_with_ground_truth, skip_if_insufficient_gpus, - create_kvmanager_with_mode, create_gpu_kv_layout + create_gpu_kv_layout, GPUKVCacheVerifier ) +def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_config, num_gpu_blocks, child_conn): + """Run tp_client process""" + try: + device_id = tp_rank + dp_client_id * model_config.tp_size + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) + + gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks) + + # Create GPU blocks for this tp_rank in the tp_client process + gpu_blocks_for_tp = [] + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) + + # Send GPU blocks back to main process via pipe if connection provided + if child_conn is not None: + print(f"[TP Client {tp_rank}] Converting {len(gpu_blocks_for_tp)} GPU blocks to TensorSharedHandle") + shared_gpu_blocks = [TensorSharedHandle(tensor) for tensor in gpu_blocks_for_tp] + child_conn.send(shared_gpu_blocks) + print(f"[TP Client {tp_rank}] Sent GPU blocks to main process via pipe") + child_conn.close() + + # Keep the process running + while True: + time.sleep(1) + except Exception as e: + if child_conn is not None: + child_conn.send(None) + child_conn.close() + 
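The helper above is the client half of the new registration handshake: each TP rank runs in its own process, allocates its per-layer GPU block tensors, registers them with the KVManager over gpu_register_port via KVTPClient, and sends TensorSharedHandle wrappers back through a Pipe so the parent process can verify KV data later. A minimal sketch of that client side, reusing only calls that appear in this diff (the function name, argument list and single-rank setup are illustrative assumptions, not part of the patch):

    # Sketch: register one TP rank's GPU blocks with a running KVManager (assumed setup).
    import torch
    from flexkv.server.client import KVTPClient

    def register_rank(gpu_register_port, dp_client_id, tp_rank, tp_size,
                      num_layers, gpu_kv_layout, dtype):
        device_id = tp_rank + dp_client_id * tp_size
        client = KVTPClient(gpu_register_port, dp_client_id, device_id)
        # One tensor per layer, shaped by the non-layer dims of the GPU KV layout.
        blocks = [torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=dtype).cuda(device_id)
                  for _ in range(num_layers)]
        client.register_to_server(blocks, gpu_kv_layout)
        # KVManager.is_ready() only turns True after tp_size * dp_size ranks have registered.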
+def shutdown_tp_client(tp_client_processes): + for tp_process in tp_client_processes: + if tp_process.is_alive(): + tp_process.terminate() + tp_process.join(timeout=5) + if tp_process.is_alive(): + print(f"Force killing tp_client process {tp_process.pid}") + tp_process.kill() + tp_process.join(timeout=2) @pytest.mark.parametrize("model_config", [ {'tp_size': 1, 'dp_size': 1}, @@ -31,23 +79,18 @@ {'enable_cpu': True, 'enable_ssd': True, 'enable_remote': False, 'ssd_cache_iouring_entries': 512}, {'enable_cpu': True, 'enable_ssd': True, 'enable_remote': True, 'num_ssd_blocks': 256, 'num_remote_blocks': 512}, {'enable_cpu': True, 'enable_ssd': True, 'enable_remote': True, - 'num_ssd_blocks': 256, 'num_remote_blocks': 512, 'ssd_cache_iouring_entries': 512}, + 'num_ssd_blocks': 256, 'num_remote_blocks': 512, 'ssd_cache_iouring_entries': 512}, ], indirect=True) @pytest.mark.parametrize("test_config", [ - {'num_gpu_blocks': 512, 'requests_per_block': 16, 'initial_write_ratio': 0.4, 'use_server_client': False}, - {'num_gpu_blocks': 512, 'requests_per_block': 16, 'initial_write_ratio': 0.4, 'use_server_client': True}, + {'num_gpu_blocks': 512, 'requests_per_block': 16, 'initial_write_ratio': 0.4}, ], indirect=True) @pytest.mark.parametrize("flex_kv_layout_type", [ KVCacheLayoutType.LAYERWISE, KVCacheLayoutType.BLOCKWISE, ]) def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type): - num_layers = model_config.num_layers - num_kv_heads = model_config.num_kv_heads - head_size = model_config.head_size tp_size = model_config.tp_size dp_size = model_config.dp_size - use_mla = model_config.use_mla tokens_per_block = cache_config.tokens_per_block num_cpu_blocks = cache_config.num_cpu_blocks @@ -64,7 +107,6 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) num_gpu_blocks = test_config["num_gpu_blocks"] block_per_request = test_config['requests_per_block'] initial_write_ratio = test_config['initial_write_ratio'] - use_server_client = test_config.get('use_server_client', False) num_requests = num_gpu_blocks // block_per_request @@ -74,50 +116,97 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) if enable_remote: pytest.skip("skip because enable_remote is not supported") - if use_server_client and dp_size > 1: - pytest.skip("skip because server-client mode is not supported for dp_size > 1 IN THIS TEST SCRIPT now") + if dp_size > 1: + #note that for now only dp_size=1 is supported + pytest.skip("skip because server-client mode is not ready for dp_size > 1") + + import uuid + gpu_register_port = f"ipc:///tmp/flexkv_gpu_{uuid.uuid4().hex[:8]}" + server_recv_port = f"ipc:///tmp/flexkv_srv_{uuid.uuid4().hex[:8]}" + kvmanager = KVManager(model_config, cache_config, gpu_register_port, server_recv_port) + kvmanager.start() + + # Create pipes for each tp_client to send GPU blocks back + pipe_connections = [] + tp_client_processes = [] + + for tp_rank in range(tp_size): + parent_conn, child_conn = Pipe() + pipe_connections.append(parent_conn) - if use_server_client: - # In server-client mode, GPU blocks are created in tp_client processes - # We only need the layout for initialization + tp_client_process = Process( + target=run_tp_client, + args=(0, tp_rank, gpu_register_port, model_config, cache_config, num_gpu_blocks + tp_rank, child_conn), + daemon=True + ) + tp_client_processes.append(tp_client_process) + tp_client_process.start() + + # Collect GPU blocks from all tp_client processes + print(f"[Main Process] Waiting to receive 
GPU blocks from {tp_size} TP client processes...") + all_gpu_blocks = [] + + for tp_rank, parent_conn in enumerate(pipe_connections): + try: + shared_gpu_blocks = parent_conn.recv() + if shared_gpu_blocks is not None: + all_gpu_blocks.append(shared_gpu_blocks) + print(f"[Main Process] Received GPU blocks from TP client {tp_rank}") + else: + print(f"[Main Process] TP client {tp_rank} failed to create GPU blocks") + parent_conn.close() + except Exception as e: + print(f"[Main Process] Error receiving from TP client {tp_rank}: {e}") + + # Create GPUKVCacheVerifier with collected GPU blocks + if all_gpu_blocks and len(all_gpu_blocks) == tp_size: + print(f"[Main Process] Creating GPUKVCacheVerifier with GPU blocks from {len(all_gpu_blocks)} TP clients") + + # Get gpu_kv_layout from cache_config for GPUKVCacheVerifier gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks) - gpu_blocks = None # Not used in server-client mode - dp_wise_gpu_blocks_gt = None # Not used in server-client mode + + gpu_kv_verifier = GPUKVCacheVerifier( + shared_gpu_blocks=all_gpu_blocks, + gpu_kv_layout=gpu_kv_layout, + tp_size=model_config.tp_size, + tokens_per_block=cache_config.tokens_per_block, + dtype=model_config.dtype + ) + print("[Main Process] GPUKVCacheVerifier created successfully") else: - # In direct mode, create GPU blocks in current process - gpu_blocks, dp_wise_gpu_blocks_gt, gpu_kv_layout = \ - generate_gpu_blocks_with_ground_truth(model_config, cache_config, test_config) + print(f"[Main Process] Failed to collect GPU blocks from all TP clients. " + f"Got {len(all_gpu_blocks)} out of {tp_size}") + gpu_kv_verifier = None - kvmanager = create_kvmanager_with_mode(model_config, cache_config, gpu_kv_layout, gpu_blocks, use_server_client) + while not kvmanager.is_ready(): + time.sleep(1) + flexkv_logger.info("waiting for flexkv to be ready") - # put this after KVManager() num_remote_blocks = cache_config.num_remote_blocks - assert kvmanager.is_ready() - kvmanager.start() request_pairs = [generate_request_pair(i, block_per_request, num_gpu_blocks, tokens_per_block, dp_size) for i in range(num_requests)] initial_write_num = int(num_requests * initial_write_ratio) print("writing initial data...") + put_ids = [] for token_ids, block_ids, dp_id in request_pairs[:initial_write_num]: + if gpu_kv_verifier is not None: + gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) write_request = kvmanager.put_async( token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), + token_mask=None, dp_id=dp_id, ) - kvmanager.wait_for_graph_finished(write_request) - if not use_server_client: - # In direct mode, update GPU blocks for verification - for gpu in range(dp_id * tp_size, (dp_id + 1) * tp_size): - for i in range(num_layers): - gpu_blocks[gpu][i][:, block_ids, :, :, :] = 0 + kvmanager.wait([write_request], completely=True) #corner case: input token length for put is less than tokens_per_block write_request = kvmanager.put_async( token_ids=torch.randint(0, 100, size=(8,), dtype=torch.int64), slot_mapping=block_ids_2_slot_mapping(torch.arange(0,1, dtype=torch.int64), tokens_per_block, actual_length=8), + token_mask=None, dp_id=0, ) - kvmanager.wait_for_graph_finished(write_request) + kvmanager.wait([write_request], completely=True) #corner case: input token length is long enough, but the mask is less than tokens_per_block #my_mask = torch.zeros(16, dtype=torch.bool) #my_mask[0:8] = True @@ -134,44 +223,77 @@ def test_kvmanager(model_config, cache_config, test_config, 
flex_kv_layout_type) total_cache_miss = 0 running_get_requests = [] running_put_requests = [] + req_id2block_ids = {} + req_id2token_ids = {} + flexkv_id2req_id = {} start_time = time.time() print(f"the initial {initial_write_num} write done,performing mixed read/write...") for i in range(initial_write_num, num_requests): print(f"performing mixed read/write {i} / {num_requests} ...") read_idx = i - initial_write_num token_ids, block_ids, dp_id = request_pairs[read_idx] - request_id = kvmanager.get_async( + slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) + request_id, _ = kvmanager.get_match( token_ids=token_ids, - slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), layer_granularity=-1, + token_mask=None, dp_id=dp_id, ) + kvmanager.launch(request_id, slot_mapping) + flexkv_id2req_id[request_id] = read_idx running_get_requests.append(request_id) + req_id2block_ids[request_id] = block_ids + req_id2token_ids[request_id] = token_ids token_ids, block_ids, dp_id = request_pairs[i] + if gpu_kv_verifier is not None: + gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) request_id = kvmanager.put_async( token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), + token_mask=None, dp_id=dp_id, ) + flexkv_id2req_id[request_id] = i + print(f"write flexkv request_id {request_id} to req_id {i}") running_put_requests.append(request_id) min_block_num = min(num_cpu_blocks, num_gpu_blocks) if (len(running_get_requests) + len(running_put_requests) >= min_block_num // block_per_request - 2 or i % initial_write_num == initial_write_num - 1 or i == num_requests - 1): if len(running_put_requests) > 0: - kvmanager.wait_for_graph_finished(running_put_requests) + kvmanager.wait(running_put_requests, completely=True) if len(running_get_requests) > 0: - return_masks = kvmanager.wait(running_get_requests) - for return_mask in return_masks.values(): - total_cache_hit += return_mask.sum() - total_cache_miss += len(return_mask) - return_mask.sum() + return_results = kvmanager.wait(running_get_requests, completely=True) + if gpu_kv_verifier is not None: + for req_id, kvresponse in return_results.items(): + assert kvresponse.status == KVResponseStatus.SUCCESS + valid_fetched_tokens = kvresponse.return_mask.sum().item() // \ + tokens_per_block * tokens_per_block + token_ids = req_id2token_ids[req_id] + block_ids = req_id2block_ids[req_id] + assert gpu_kv_verifier.verify_kv_blocks( + token_ids[:valid_fetched_tokens], + block_ids[:valid_fetched_tokens//tokens_per_block]) + for kvresponse in return_results.values(): + assert kvresponse.status == KVResponseStatus.SUCCESS + total_cache_hit += kvresponse.return_mask.sum().item() + total_cache_miss += len(kvresponse.return_mask) - kvresponse.return_mask.sum().item() running_get_requests = [] running_put_requests = [] if len(running_get_requests) > 0: - kvmanager.wait(running_get_requests) + return_results = kvmanager.wait(running_get_requests, completely=True) + if gpu_kv_verifier is not None: + for req_id, kvresponse in return_results.items(): + assert kvresponse.status == KVResponseStatus.SUCCESS + valid_fetched_tokens = kvresponse.return_mask.sum().item() // tokens_per_block * tokens_per_block + token_ids = req_id2token_ids[req_id] + block_ids = req_id2block_ids[req_id] + assert gpu_kv_verifier.verify_kv_blocks( + token_ids[:valid_fetched_tokens], + block_ids[:valid_fetched_tokens//tokens_per_block]) running_get_requests = [] if len(running_put_requests) > 0: - 
kvmanager.wait_for_graph_finished(running_put_requests) + kvmanager.wait(running_put_requests, completely=True) running_put_requests = [] print("mixed read/write done") end_time = time.time() @@ -182,12 +304,12 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) enable_ssd and num_ssd_blocks >= num_gpu_blocks or \ enable_remote and num_remote_blocks >= num_gpu_blocks: assert total_cache_miss == 0 + shutdown_tp_client(tp_client_processes) kvmanager.shutdown() - if total_cache_miss == 0 and not use_server_client: - # Only verify data in direct mode - verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, num_kv_heads, tp_size, dp_size, num_layers, use_mla) + # Only verify data in direct mode + # verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, num_kv_heads, tp_size, dp_size, num_layers, use_mla) + if total_cache_miss == 0: + return elif total_cache_miss > 0: - print(f"verify skipped, because of total_cache_miss={total_cache_miss}>0") - elif use_server_client: - print("verify skipped in server-client mode (verification happens in tp_client processes)") + print(f"verify skipped, because of total_cache_miss={total_cache_miss} > 0") diff --git a/tests/test_transfer_engine.py b/tests/test_transfer_engine.py deleted file mode 100644 index 73ee107864..0000000000 --- a/tests/test_transfer_engine.py +++ /dev/null @@ -1,411 +0,0 @@ -""" -Transfer Engine Unit Tests - -This module contains comprehensive unit tests for the TransferEngine component, -which handles data transfers between different storage tiers (GPU, CPU, SSD). - -Test Functions Overview: -1. test_gpu_cpu_round_trip: Tests round-trip data transfers between GPU and CPU - - Parameterized by: tp_size, dp_size, num_gpu_blocks, transfer_block_num - - Validates data consistency after GPU->CPU->GPU transfers - -2. test_ssd_round_trip: Tests round-trip data transfers involving SSD storage - - Parameterized by: num_gpu_blocks, transfer_block_num, enable_ssd_cache - - Validates data consistency after GPU->CPU->SSD->CPU->GPU transfers - -3. test_concurrent_mixed_transfers: Tests multiple concurrent read/write transfers - - Parameterized by: num_concurrent_transfers, blocks_per_transfer, include_ssd - - Validates correctness of mixed read/write transfer graphs running concurrently - -usage example: - python -m pytest tests/test_transfer_engine.py::test_gpu_cpu_round_trip -v --tb=short -Each test validates both transfer completion and data correctness to ensure -the TransferEngine maintains data integrity across all transfer operations. 
-""" - -import os -import time -import tempfile -from typing import List, Dict, Tuple -import multiprocessing as mp -from contextlib import suppress - -import pytest -import torch - -from flexkv.cache.transfer_pattern import ( - create_read_graph_cpu_storage, - create_write_graph_cpu_storage, -) -from flexkv.common.config import ModelConfig, CacheConfig -from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType -from flexkv.common.transfer import DeviceType -from flexkv.storage.storage_engine import StorageEngine -from flexkv.transfer.transfer_engine import TransferEngine - -# Import utilities from test_utils -from test_utils import ( - wait_for_transfer_completion, - skip_if_no_cuda, - skip_if_insufficient_gpus, - generate_gpu_blocks_with_ground_truth, - verify_data -) - -@pytest.mark.parametrize("tp_size,dp_size", [(1, 1), (2, 1), (2, 2)]) -@pytest.mark.parametrize("num_gpu_blocks", [128]) -@pytest.mark.parametrize("transfer_block_num", [16]) -@pytest.mark.parametrize("use_mla", [False, True]) -@pytest.mark.parametrize("underlying_layout_type", [KVCacheLayoutType.LAYERWISE, KVCacheLayoutType.BLOCKWISE]) -def test_gpu_cpu_round_trip(model_config, - cache_config, - test_config, - tp_size, - dp_size, - num_gpu_blocks, - transfer_block_num, - use_mla, - underlying_layout_type): - """ - Test round-trip data transfers between GPU and CPU - - This test validates: - 1. GPU -> CPU transfer correctness - 2. CPU -> GPU transfer correctness - 3. Round-trip data consistency (GPU -> CPU -> GPU) - - Parameterized by: - - tp_size, dp_size: Tensor and data parallelism configurations - - num_gpu_blocks: Number of GPU blocks to test with - - transfer_block_num: Number of blocks to transfer in each operation - """ - total_gpus = tp_size * dp_size - skip_if_insufficient_gpus(total_gpus) - - if transfer_block_num > num_gpu_blocks: - pytest.skip(f"transfer_block_num ({transfer_block_num}) > num_gpu_blocks ({num_gpu_blocks})") - - # Update model config - model_config.use_mla = use_mla - model_config.tp_size = tp_size - model_config.dp_size = dp_size - - # Create a copy of test_config to avoid modifying the fixture - test_config_copy = test_config.copy() - test_config_copy['num_gpu_blocks'] = num_gpu_blocks - - cache_config.cpu_kv_layout_type = underlying_layout_type - # Setup configurations - cache_config.enable_ssd = False - - gpu_blocks, dp_wise_gpu_blocks_gt, gpu_kv_layout = \ - generate_gpu_blocks_with_ground_truth(model_config, cache_config, test_config_copy) - # Setup storage engine and transfer engine - storage_engine = StorageEngine(model_config, cache_config) - for gpu_id, gpu_block in gpu_blocks.items(): - storage_engine.register_gpu_blocks(gpu_block, gpu_kv_layout, device_id=gpu_id, dtype=model_config.dtype) - gpu_handles = [storage_engine.get_storage_handle(DeviceType.GPU, i) for i in range(total_gpus)] - cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU) - - transfer_engine = TransferEngine( - gpu_handles=gpu_handles, - model_config=model_config, - cache_config=cache_config, - cpu_handle=cpu_handle - ) - transfer_engine.start() - - # Test each DP group separately - for dp_id in range(dp_size): - gpu_block_ids = torch.arange(0, transfer_block_num, dtype=torch.int64) - cpu_block_ids = torch.arange(dp_id * transfer_block_num, (dp_id + 1) * transfer_block_num, dtype=torch.int64) - - # Step 1: GPU -> CPU transfer - write_graph, _ = create_write_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=torch.tensor([], dtype=torch.int64), - 
gpu_device_id=dp_id * tp_size, - layer_num=model_config.num_layers - ) - write_graph.bind_to_dp_group(dp_id) - transfer_engine.submit_transfer_graph(write_graph) - - # Wait for write completion - assert wait_for_transfer_completion(transfer_engine, [write_graph.graph_id]), \ - f"GPU->CPU transfer failed for DP group {dp_id}" - - # Clear GPU blocks for read test - for tp_id in range(tp_size): - global_gpu_id = dp_id * tp_size + tp_id - for layer_id in range(model_config.num_layers): - gpu_blocks[global_gpu_id][layer_id][:,gpu_block_ids].zero_() - - # Step 2: CPU -> GPU transfer - read_graph, _ = create_read_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=torch.tensor([], dtype=torch.int64), - gpu_device_id=dp_id * tp_size, - layer_num=model_config.num_layers - ) - read_graph.bind_to_dp_group(dp_id) - transfer_engine.submit_transfer_graph(read_graph) - # Wait for read completion - assert wait_for_transfer_completion(transfer_engine, [read_graph.graph_id]), \ - f"CPU->GPU transfer failed for DP group {dp_id}" - - verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, model_config.num_kv_heads, - model_config.tp_size, model_config.dp_size, model_config.num_layers, model_config.use_mla) - # Cleanup - transfer_engine.shutdown() - - -@pytest.mark.parametrize("num_gpu_blocks", [64, 128]) -@pytest.mark.parametrize("transfer_block_num", [8, 16]) -@pytest.mark.parametrize("use_mla", [True, False]) -@pytest.mark.parametrize("iouring_entries", [0, 512]) -def test_ssd_round_trip(model_config, - cache_config, - test_config, - num_gpu_blocks, - transfer_block_num, - use_mla, - iouring_entries): - """ - Test round-trip data transfers involving SSD storage - - This test validates: - 1. GPU -> CPU -> SSD transfer chain - 2. SSD -> CPU -> GPU transfer chain - 3. 
Full round-trip data consistency - - Parameterized by: - - num_gpu_blocks: Number of GPU blocks to test with - - transfer_block_num: Number of blocks to transfer - """ - skip_if_no_cuda() - - if transfer_block_num > num_gpu_blocks: - pytest.skip(f"transfer_block_num ({transfer_block_num}) > num_gpu_blocks ({num_gpu_blocks})") - - # Setup configurations - cache_config.enable_ssd = True - cache_config.ssd_cache_iouring_entries = iouring_entries - model_config.use_mla = use_mla - - # Create a copy of test_config to avoid modifying the fixture - test_config_copy = test_config.copy() - test_config_copy['num_gpu_blocks'] = num_gpu_blocks - - gpu_blocks, dp_wise_gpu_blocks_gt, gpu_kv_layout = \ - generate_gpu_blocks_with_ground_truth(model_config, cache_config, test_config_copy) - if (model_config.tp_size * model_config.dp_size) > 1: - pytest.skip("SSD transfer test is not supported for multi-GPU") - - # Setup storage engine and transfer engine - storage_engine = StorageEngine(model_config, cache_config) - for gpu_id, gpu_block in gpu_blocks.items(): - storage_engine.register_gpu_blocks(gpu_block, gpu_kv_layout, device_id=gpu_id, dtype=model_config.dtype) - gpu_handles = [storage_engine.get_storage_handle(DeviceType.GPU, i) - for i in range(model_config.tp_size * model_config.dp_size)] - cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU) - ssd_handle = storage_engine.get_storage_handle(DeviceType.SSD) - - transfer_engine = TransferEngine( - gpu_handles=gpu_handles, - model_config=model_config, - cache_config=cache_config, - cpu_handle=cpu_handle, - ssd_handle=ssd_handle - ) - transfer_engine.start() - # Prepare transfer block IDs - gpu_block_ids = torch.arange(0, transfer_block_num, dtype=torch.int64) - cpu_block_ids = torch.arange(0, transfer_block_num, dtype=torch.int64) - ssd_block_ids = torch.arange(0, transfer_block_num, dtype=torch.int64) - - # Step 1: GPU -> CPU -> SSD write - write_graph, _ = create_write_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=ssd_block_ids, - gpu_device_id=0, - layer_num=model_config.num_layers - ) - write_graph.bind_to_dp_group(0) - transfer_engine.submit_transfer_graph(write_graph) - - # Wait for write completion - assert wait_for_transfer_completion(transfer_engine, [write_graph.graph_id]), \ - "GPU->CPU->SSD write transfer failed" - - # Clear GPU blocks for read test - for gpu_id in range(model_config.tp_size * model_config.dp_size): - for layer_id in range(model_config.num_layers): - gpu_blocks[gpu_id][layer_id][:,gpu_block_ids].zero_() - - # Step 2: SSD -> CPU -> GPU read - read_graph, _ = create_read_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=ssd_block_ids, - gpu_device_id=0, - layer_num=model_config.num_layers - ) - read_graph.bind_to_dp_group(0) - transfer_engine.submit_transfer_graph(read_graph) - - # Wait for read completion - assert wait_for_transfer_completion(transfer_engine, [read_graph.graph_id]), \ - "SSD->CPU->GPU read transfer failed" - - verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, model_config.num_kv_heads, - model_config.tp_size, model_config.dp_size, model_config.num_layers, model_config.use_mla) - - # Cleanup - transfer_engine.shutdown() - - -@pytest.mark.parametrize("num_concurrent_transfers", [4]) -@pytest.mark.parametrize("blocks_per_transfer", [16]) -@pytest.mark.parametrize("include_ssd", [True, False]) -@pytest.mark.parametrize("use_mla", [True, False]) -@pytest.mark.parametrize("iouring_entries", [0, 512]) -def 
test_concurrent_mixed_transfers(model_config, - cache_config, - test_config, - num_concurrent_transfers, - blocks_per_transfer, - include_ssd, - use_mla, - iouring_entries): - """ - Test multiple concurrent read/write transfers - - This test validates: - 1. Multiple write transfers running concurrently - 2. Multiple read transfers running concurrently - 3. Mixed read/write transfers running concurrently - 4. Data correctness across all concurrent operations - - Parameterized by: - - num_concurrent_transfers: Number of concurrent transfer graphs - - blocks_per_transfer: Number of blocks per transfer operation - - include_ssd: Whether to include SSD in transfer operations - """ - model_config.use_mla = use_mla - skip_if_no_cuda() - - if (model_config.tp_size * model_config.dp_size) > 1: - pytest.skip("Concurrent transfer test is not supported for multi-GPU") - - total_blocks_needed = num_concurrent_transfers * blocks_per_transfer * 2 # For both read and write - num_gpu_blocks = max(128, total_blocks_needed) - - cache_config.num_cpu_blocks = num_gpu_blocks - cache_config.num_ssd_blocks = num_gpu_blocks - cache_config.ssd_cache_iouring_entries = iouring_entries - - # Setup configurations - cache_config.enable_ssd = include_ssd - - # Create a copy of test_config to avoid modifying the fixture - test_config_copy = test_config.copy() - test_config_copy['num_gpu_blocks'] = num_gpu_blocks - - gpu_blocks, dp_wise_gpu_blocks_gt, gpu_kv_layout = \ - generate_gpu_blocks_with_ground_truth(model_config, cache_config, test_config_copy) - - # Setup storage engine and transfer engine - storage_engine = StorageEngine(model_config, cache_config) - for gpu_id, gpu_block in gpu_blocks.items(): - storage_engine.register_gpu_blocks(gpu_block, gpu_kv_layout, device_id=gpu_id, dtype=model_config.dtype) - gpu_handles = [storage_engine.get_storage_handle(DeviceType.GPU, i) - for i in range(model_config.tp_size * model_config.dp_size)] - cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU) - ssd_handle = storage_engine.get_storage_handle(DeviceType.SSD) if include_ssd else None - - transfer_engine = TransferEngine( - gpu_handles=gpu_handles, - model_config=model_config, - cache_config=cache_config, - cpu_handle=cpu_handle, - ssd_handle=ssd_handle - ) - - transfer_engine.start() - # Create concurrent write transfers - write_graphs = [] - - for i in range(num_concurrent_transfers): - start_block = i * blocks_per_transfer - end_block = start_block + blocks_per_transfer - - gpu_block_ids = torch.arange(start_block, end_block, dtype=torch.int64) - cpu_block_ids = torch.arange(start_block, end_block, dtype=torch.int64) - ssd_block_ids = torch.arange(start_block, end_block, dtype=torch.int64) \ - if include_ssd else torch.tensor([], dtype=torch.int64) - - - write_graph, _ = create_write_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=ssd_block_ids, - gpu_device_id=0, - layer_num=model_config.num_layers - ) - write_graph.bind_to_dp_group(0) - write_graphs.append(write_graph) - - # Submit all write transfers - for graph in write_graphs: - transfer_engine.submit_transfer_graph(graph) - - # Wait for all writes to complete - write_graph_ids = [graph.graph_id for graph in write_graphs] - assert wait_for_transfer_completion(transfer_engine, write_graph_ids, max_wait_time=20.0), \ - "Concurrent write transfers failed to complete" - - # Clear GPU blocks for read test - for gpu_id in range(model_config.tp_size * model_config.dp_size): - gpu_block_ids = torch.arange(0, 
(num_concurrent_transfers + 1) * blocks_per_transfer, dtype=torch.int64) - for layer_id in range(model_config.num_layers): - gpu_blocks[gpu_id][layer_id][:,gpu_block_ids].zero_() - - # Create concurrent read transfers (using different GPU blocks) - read_graphs = [] - - for i in range(num_concurrent_transfers): - gpu_block_ids = torch.arange(i * blocks_per_transfer, (i + 1) * blocks_per_transfer, dtype=torch.int64) - cpu_block_ids = torch.arange(i * blocks_per_transfer, (i + 1) * blocks_per_transfer, dtype=torch.int64) - ssd_block_ids = torch.arange(i * blocks_per_transfer, (i + 1) * blocks_per_transfer, dtype=torch.int64) \ - if include_ssd else torch.tensor([], dtype=torch.int64) - - read_graph, _ = create_read_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=ssd_block_ids, - gpu_device_id=0, - layer_num=model_config.num_layers - ) - read_graph.bind_to_dp_group(0) - read_graphs.append(read_graph) - - # Submit all read transfers - for graph in read_graphs: - transfer_engine.submit_transfer_graph(graph) - - # Wait for all reads to complete - read_graph_ids = [graph.graph_id for graph in read_graphs] - assert wait_for_transfer_completion(transfer_engine, read_graph_ids, max_wait_time=20.0), \ - "Concurrent read transfers failed to complete" - - verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, model_config.num_kv_heads, - model_config.tp_size, model_config.dp_size, model_config.num_layers, model_config.use_mla) - - # Cleanup - transfer_engine.shutdown() - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_utils.py b/tests/test_utils.py index a7e2f166c5..93541b612b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,14 +1,16 @@ import time import os import shutil -from typing import List, Dict, Tuple -from multiprocessing import Process - +from typing import List, Dict, Tuple, Optional, Union +from multiprocessing import Process, Pipe, Queue +import pickle +import multiprocessing as mp import pytest import torch from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.memory_handle import TensorSharedHandle # Default configurations @@ -36,9 +38,8 @@ 'remote_file_prefix': "remote_cache", 'use_gds': False, 'enable_trace': False, - 'use_pinned_memory': False, 'ssd_cache_dir': ["./ssd_cache", "./ssd_cache2/"], - 'ssd_cache_iouring_entries': 0, + 'ssd_cache_iouring_entries': 32, 'remote_cache_path': ["remote_cache1", "remote_cache2"], 'remote_config_custom': { "pcfs_fsid": "f_l91fz6", @@ -47,9 +48,9 @@ "pcfs_parent_nodeid": 144115188075855883 # Using transfer engine value for consistency }, 'use_ce_transfer_h2d': False, - 'use_ce_transfer_d2h': True, - 'transfer_sms_h2d': 4, - 'transfer_sms_d2h': 4, + 'use_ce_transfer_d2h': False, + 'transfer_sms_h2d': 8, + 'transfer_sms_d2h': 8, } DEFAULT_TEST_CONFIG = { @@ -292,7 +293,7 @@ def __init__(self, model_config, cache_config, gpu_kv_layout, gpu_blocks): tp_client_process = Process( target=KVManagerServerClient._run_tp_client, args=(self.dp_client.dp_client_id, tp_rank, device_id, self.server_recv_port, - model_config.num_layers, str(model_config.dtype), + model_config.num_layers, str(model_config.dtype), list(gpu_kv_layout.kv_shape[1:]), model_config.use_mla), daemon=True ) @@ -340,7 +341,7 @@ def _run_tp_client(dp_client_id, tp_rank, device_id, server_recv_port, num_layer from flexkv.server.client import KVTPClient from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType - 
tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank) + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) # Convert dtype string back to torch dtype if dtype_str == "torch.float16": dtype = torch.float16 @@ -449,11 +450,310 @@ def shutdown(self): print("KVManagerServerClient shutdown complete") -def create_kvmanager_with_mode(model_config, cache_config, gpu_kv_layout, gpu_blocks, use_server_client=False): - """Create KVManager with optional server-client mode""" - if use_server_client: - print("Using server-client mode") - return KVManagerServerClient(model_config, cache_config, gpu_kv_layout, gpu_blocks) - else: - from flexkv.kvmanager import KVManager - return KVManager(model_config, cache_config, gpu_kv_layout, gpu_blocks) +class GPUKVCacheVerifier: + def __init__(self, + shared_gpu_blocks: Union[List[torch.Tensor], List[TensorSharedHandle], List[List[TensorSharedHandle]]], + gpu_kv_layout: KVCacheLayout, + tp_size: int, + tokens_per_block: int, + dtype: torch.dtype)->None: + self.gpu_kv_layout = gpu_kv_layout + self.num_layers = gpu_kv_layout.num_layer + # we have to map the exported gpu blocks into the virtual space of current process + if isinstance(shared_gpu_blocks[0], torch.Tensor): + self.gpu_blocks = shared_gpu_blocks + elif isinstance(shared_gpu_blocks[0], TensorSharedHandle): + self.gpu_blocks = [wrapper.get_tensor() for wrapper in shared_gpu_blocks] + else: + imported_gpu_blocks = [] + for handles_in_one_gpu in shared_gpu_blocks: + blocks_in_one_gpu = [] + for handle in handles_in_one_gpu: + blocks_in_one_gpu.append(handle.get_tensor()) + imported_gpu_blocks.append(blocks_in_one_gpu) + self.gpu_blocks = imported_gpu_blocks + self.gpu_block_num = gpu_kv_layout.num_block + self.tp_size = tp_size + self.is_mla = gpu_kv_layout.is_mla + self.tokens_per_block = tokens_per_block + self.dtype = dtype + + + def hash_all_values(self, layer_id, kv_id, token_ids, head_id): + base_hash = hash((layer_id, kv_id, head_id)) + + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.tolist() + + token_hash = 0 + prime = 31 + for i, token_id in enumerate(token_ids): + token_hash += (token_id * (prime ** i)) % (2**31 - 1) + + combined_hash = (base_hash + token_hash) % (2**31 - 1) + + normalized_value = (combined_hash % 1000000) / 1000000.0 + + return torch.tensor(normalized_value, dtype=self.dtype).item() + + def fill_gpu_blocks(self, token_ids, block_ids): + assert len(token_ids) == len(block_ids) * self.tokens_per_block + + # Ensure token_ids is in tensor format + if not isinstance(token_ids, torch.Tensor): + token_ids = torch.tensor(token_ids, dtype=torch.int64) + if not isinstance(block_ids, torch.Tensor): + block_ids = torch.tensor(block_ids, dtype=torch.int64) + + for layer_id in range(self.num_layers): + kv_num = 2 if not self.is_mla else 1 + for kv_id in range(kv_num): + for tp_id in range(self.tp_size): + if isinstance(self.gpu_blocks[0], list): + # multiple gpu:gpu_blocks[tp_id][layer_id] + gpu_tensor = self.gpu_blocks[tp_id][layer_id] + else: + # single gpu:gpu_blocks[layer_id] + gpu_tensor = self.gpu_blocks[layer_id] + + for head_id in range(self.gpu_kv_layout.num_head): + actual_head_id = tp_id * self.gpu_kv_layout.num_head + head_id if not self.is_mla else head_id + + for block_idx, block_id in enumerate(block_ids): + start_token_idx = block_idx * self.tokens_per_block + end_token_idx = start_token_idx + self.tokens_per_block + hash_value = self.hash_all_values(layer_id, + kv_id, + token_ids[start_token_idx:end_token_idx], + 
actual_head_id) + # GPU tensor dim:[kv_dim, num_block, tokens_per_block, num_head, head_size] + gpu_tensor[kv_id, block_id, :, head_id, :] = hash_value + + def verify_kv_blocks(self, token_ids, block_ids)->bool: + assert len(token_ids) == len(block_ids) * self.tokens_per_block + + if not isinstance(token_ids, torch.Tensor): + token_ids = torch.tensor(token_ids, dtype=torch.int64) + if not isinstance(block_ids, torch.Tensor): + block_ids = torch.tensor(block_ids, dtype=torch.int64) + + verification_passed = True + errors = [] + + for layer_id in range(self.num_layers): + kv_num = 2 if not self.is_mla else 1 + for kv_id in range(kv_num): + for tp_id in range(self.tp_size): + if isinstance(self.gpu_blocks[0], list): + gpu_tensor = self.gpu_blocks[tp_id][layer_id] + else: + gpu_tensor = self.gpu_blocks[layer_id] + + for head_id in range(self.gpu_kv_layout.num_head): + actual_head_id = tp_id * self.gpu_kv_layout.num_head + head_id if not self.is_mla else head_id + for block_idx, block_id in enumerate(block_ids): + start_token_idx = block_idx * self.tokens_per_block + end_token_idx = start_token_idx + self.tokens_per_block + expected_hash_value = self.hash_all_values(layer_id, kv_id, + token_ids[start_token_idx:end_token_idx], + actual_head_id) + + actual_values = gpu_tensor[kv_id, block_id, :, head_id, :] + + if not torch.allclose(actual_values, + torch.full_like(actual_values, expected_hash_value), + rtol=1e-5, atol=1e-6): + verification_passed = False + errors.append( + f"Mismatch at layer={layer_id}, kv={kv_id}, tp={tp_id}, " + f"head={head_id}, block={block_id}: " + f"expected={expected_hash_value}, got={actual_values[0, 0].item()}" + ) + + if not verification_passed: + print(f"Verification failed with {len(errors)} errors:") + for error in errors[:10]: + print(f" {error}") + if len(errors) > 10: + print(f" ... 
and {len(errors) - 10} more errors") + else: + print("KV blocks verification passed!") + + return verification_passed + + +def gpu_blocks_worker_process(conn, model_config, cache_config, gpu_kv_layout): + try: + print(f"[Worker Process {os.getpid()}] Starting to create GPU blocks...") + + # Create GPU blocks in subprocess + gpu_blocks = [] + for layer_id in range(model_config.num_layers): + # LAYERWISE format: [kv_dim, num_block, tokens_per_block, num_head, head_size] + kv_dim = 2 if not model_config.use_mla else 1 + gpu_tensor = torch.zeros( + kv_dim, + gpu_kv_layout.num_block, + gpu_kv_layout.tokens_per_block, + gpu_kv_layout.num_head, + gpu_kv_layout.head_size, + dtype=model_config.dtype, + device='cuda:0' if torch.cuda.is_available() else 'cpu' + ) + gpu_blocks.append(gpu_tensor) + + print(f"[Worker Process {os.getpid()}] Successfully created {len(gpu_blocks)} GPU blocks") + + # Convert to TensorSharedHandle + shared_gpu_blocks = [TensorSharedHandle(tensor) for tensor in gpu_blocks] + print(f"[Worker Process {os.getpid()}] Successfully converted to {len(shared_gpu_blocks)} TensorSharedHandles") + + # Send to main process via pipe + conn.send(shared_gpu_blocks) + print(f"[Worker Process {os.getpid()}] Sent TensorSharedHandle list to main process via pipe") + + #while True: + # time.sleep(1) + conn.close() + + except Exception as e: + print(f"[Worker Process {os.getpid()}] Error occurred: {e}") + conn.send(None) + conn.close() + + +# Usage examples +def example_usage_gpu_kv_cache_verifier(): + """Demonstrates three ways to initialize GPUKVCacheVerifier""" + import torch + from flexkv.common.config import ModelConfig, CacheConfig + from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType + from flexkv.common.memory_handle import TensorSharedHandle + + # Create example configurations + model_config = ModelConfig( + num_layers=2, + num_kv_heads=8, + head_size=64, + use_mla=False, + dtype=torch.float16, + tp_size=1, + dp_size=1 + ) + + cache_config = CacheConfig( + tokens_per_block=16 + ) + + # Create GPU KV layout + gpu_kv_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERWISE, + num_layer=model_config.num_layers, + num_block=64, # Assume 64 blocks + tokens_per_block=cache_config.tokens_per_block, + num_head=model_config.num_kv_heads, + head_size=model_config.head_size, + is_mla=model_config.use_mla + ) + + # Create mock GPU blocks + gpu_blocks = [] + for layer_id in range(model_config.num_layers): + # LAYERWISE format: [kv_dim, num_block, tokens_per_block, num_head, head_size] + kv_dim = 2 if not model_config.use_mla else 1 + gpu_tensor = torch.zeros( + kv_dim, + gpu_kv_layout.num_block, + gpu_kv_layout.tokens_per_block, + gpu_kv_layout.num_head, + gpu_kv_layout.head_size, + dtype=model_config.dtype, + device='cuda:0' if torch.cuda.is_available() else 'cpu' + ) + gpu_blocks.append(gpu_tensor) + + print("=== Method 1: Direct Tensor List ===") + verifier1 = GPUKVCacheVerifier( + shared_gpu_blocks=gpu_blocks, # Pass tensor list directly + gpu_kv_layout=gpu_kv_layout, + tp_size=model_config.tp_size, + tokens_per_block=cache_config.tokens_per_block, + dtype=model_config.dtype, + ) + + print("=== Method 2: Using TensorSharedHandle (Multi-process version) ===") + mp.set_start_method('spawn') + # Create pipe for inter-process communication + parent_conn, child_conn = Pipe() + print(f"[Main Process {os.getpid()}] Successfully created pipe connection") + + # Start worker process to create GPU blocks and TensorSharedHandle + worker_process = Process( + target=gpu_blocks_worker_process, 
+ args=(child_conn, model_config, cache_config, gpu_kv_layout) + ) + + print(f"[Main Process {os.getpid()}] Starting worker process...") + worker_process.start() + + # Wait to receive TensorSharedHandle created by worker process + print(f"[Main Process {os.getpid()}] Waiting to receive results from worker process...") + shared_gpu_blocks = parent_conn.recv() + + # Wait for worker process to complete + + + if shared_gpu_blocks is None: + raise RuntimeError("Worker process failed to create GPU blocks") + + print(f"[Main Process {os.getpid()}] Successfully received {len(shared_gpu_blocks)} TensorSharedHandles") + verifier2 = GPUKVCacheVerifier( + shared_gpu_blocks=shared_gpu_blocks, + gpu_kv_layout=gpu_kv_layout, + tp_size=model_config.tp_size, + tokens_per_block=cache_config.tokens_per_block, + dtype=model_config.dtype, + ) + + # Prepare test data - Note: now hash is calculated per block + token_ids = torch.randint(0, 1000, (32,), dtype=torch.int64) # 32 tokens (2 blocks) + block_ids = torch.tensor([0, 1], dtype=torch.int64) # Use blocks 0 and 1 + + print(f"Token IDs shape: {token_ids.shape}") + print(f"Block IDs: {block_ids}") + print(f"Tokens per block: {cache_config.tokens_per_block}") + + # Test method 1 + print("\n=== Testing Method 1 (Direct Tensor) ===") + print("Starting to fill GPU blocks...") + verifier1.fill_gpu_blocks(token_ids, block_ids) + print("Filling completed!") + + print("Starting data verification...") + is_valid1 = verifier1.verify_kv_blocks(token_ids, block_ids) + print(f"Verification result: {'PASSED' if is_valid1 else 'FAILED'}") + + # Test method 2 + print("\n=== Testing Method 2 (SharedHandle) ===") + print("Starting to fill GPU blocks...") + verifier2.fill_gpu_blocks(token_ids, block_ids) + print("Filling completed!") + + print("Starting data verification...") + is_valid2 = verifier2.verify_kv_blocks(token_ids, block_ids) + print(f"Verification result: {'PASSED' if is_valid2 else 'FAILED'}") + + # Demonstrate hash calculation changes: now each block has independent hash values + print("\n=== Hash Calculation Demo ===") + for block_idx, block_id in enumerate(block_ids): + start_idx = block_idx * cache_config.tokens_per_block + end_idx = start_idx + cache_config.tokens_per_block + block_tokens = token_ids[start_idx:end_idx] + hash_value = verifier1.hash_all_values(0, 0, block_tokens, 0) + print(f"Block {block_id} tokens: {block_tokens.tolist()[:5]}... -> hash: {hash_value:.6f}") + worker_process.join() + parent_conn.close() + return verifier1, token_ids, block_ids + +if __name__ == "__main__": + example_usage_gpu_kv_cache_verifier()
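One detail of the verification path in test_kvmanager that is easy to miss: GPUKVCacheVerifier.verify_kv_blocks asserts len(token_ids) == len(block_ids) * tokens_per_block, so a partial cache hit must be rounded down to whole blocks before verifying. A small sketch with hypothetical numbers, reusing the names from the test loop:

# Hypothetical numbers; kvresponse, token_ids and block_ids come from the test loop above.
tokens_per_block = 16
matched_tokens = kvresponse.return_mask.sum().item()                  # e.g. 37 tokens matched
valid_tokens = matched_tokens // tokens_per_block * tokens_per_block  # 37 -> 32 (2 full blocks)

assert gpu_kv_verifier.verify_kv_blocks(
    token_ids[:valid_tokens],                      # only tokens covered by complete blocks
    block_ids[:valid_tokens // tokens_per_block],  # the matching GPU block ids
)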
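The test code also leans on the block_ids_2_slot_mapping helper without showing it; its definition lives elsewhere in the test suite. For orientation only, the sketch below shows the conventional block-id to per-token slot expansion it presumably performs. This is an assumption, not the repository's implementation.

import torch

def block_ids_2_slot_mapping(block_ids, tokens_per_block):
    # Presumed behaviour: block b covers slots [b * tokens_per_block, (b + 1) * tokens_per_block).
    block_ids = torch.as_tensor(block_ids, dtype=torch.int64)
    offsets = torch.arange(tokens_per_block, dtype=torch.int64)
    return (block_ids.unsqueeze(1) * tokens_per_block + offsets).flatten()

# e.g. block_ids_2_slot_mapping([2, 5], 4) -> tensor([ 8,  9, 10, 11, 20, 21, 22, 23])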