From ba27666d63c5a7893ccc861e4c8503e1d5770872 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 30 Oct 2025 10:10:26 +0000 Subject: [PATCH 1/6] blockfirst ssd io --- csrc/transfer_ssd.cpp | 74 +++++++++++++++++++++++++++++++++---------- 1 file changed, 57 insertions(+), 17 deletions(-) diff --git a/csrc/transfer_ssd.cpp b/csrc/transfer_ssd.cpp index b36c55741e..4aa72c6888 100644 --- a/csrc/transfer_ssd.cpp +++ b/csrc/transfer_ssd.cpp @@ -39,7 +39,8 @@ static void _transfer_iouring_impl( int64_t cpu_layer_stride_in_bytes, int64_t ssd_layer_stride_in_bytes, int64_t cpu_kv_stride_in_bytes, int64_t ssd_kv_stride_in_bytes, int64_t chunk_size_in_bytes, int64_t block_stride_in_bytes, - int num_files_per_device, bool is_read, bool is_mla) { + int num_files_per_device, bool is_read, bool is_mla, + bool enable_block_first_transfer) { int num_blocks = end_block - start_block; int rc; @@ -53,6 +54,39 @@ static void _transfer_iouring_impl( int fd = fd_list[ssd_block_id % num_files_per_device]; ssd_block_id /= num_files_per_device; // block id in single file + if (enable_block_first_transfer) { + int layers_chunk_size_in_bytes = + cpu_layer_stride_in_bytes * (end_layer - start_layer); + int cpu_layers_chunk_offset = start_layer * cpu_layer_stride_in_bytes; + int ssd_layers_chunk_offset = start_layer * ssd_layer_stride_in_bytes; + void *cpu_block_ptr = reinterpret_cast(cpu_tensor_ptr) + + block_stride_in_bytes * cpu_block_id + + cpu_layers_chunk_offset; + int ssd_block_offset = + ssd_block_id * block_stride_in_bytes + ssd_layers_chunk_offset; + + ssize_t bytes_transfer = 0; + if (is_read) { + rc = iouring.prep_read(fd, cpu_block_ptr, layers_chunk_size_in_bytes, + ssd_block_offset); + if (rc < 0) { + bytes_transfer = pread(fd, cpu_block_ptr, layers_chunk_size_in_bytes, + ssd_block_offset); + } + } else { + rc = iouring.prep_write(fd, cpu_block_ptr, layers_chunk_size_in_bytes, + ssd_block_offset); + if (rc < 0) { + bytes_transfer = pwrite(fd, cpu_block_ptr, 
layers_chunk_size_in_bytes, + ssd_block_offset); + } + } + if (bytes_transfer && (bytes_transfer != layers_chunk_size_in_bytes)) { + throw std::runtime_error("Failed to transfer block"); + } + continue; + } + for (int lid = start_layer; lid < end_layer; lid++) { int64_t ssd_k_block_offset = ssd_block_id * block_stride_in_bytes + lid * ssd_layer_stride_in_bytes; @@ -71,15 +105,15 @@ static void _transfer_iouring_impl( rc = iouring.prep_read(fd, cpu_k_block_ptr, chunk_size_in_bytes, ssd_k_block_offset); if (rc < 0) { - bytes_transfer = pread(fd, cpu_k_block_ptr, chunk_size_in_bytes, - ssd_k_block_offset); + bytes_transfer = pread(fd, cpu_k_block_ptr, chunk_size_in_bytes, + ssd_k_block_offset); } } else { rc = iouring.prep_write(fd, cpu_k_block_ptr, chunk_size_in_bytes, ssd_k_block_offset); if (rc < 0) { - bytes_transfer = pwrite(fd, cpu_k_block_ptr, chunk_size_in_bytes, - ssd_k_block_offset); + bytes_transfer = pwrite(fd, cpu_k_block_ptr, chunk_size_in_bytes, + ssd_k_block_offset); } } @@ -96,20 +130,20 @@ static void _transfer_iouring_impl( rc = iouring.prep_read(fd, cpu_v_block_ptr, chunk_size_in_bytes, ssd_v_block_offset); if (rc < 0) { - bytes_transfer = pread(fd, cpu_v_block_ptr, chunk_size_in_bytes, - ssd_v_block_offset); + bytes_transfer = pread(fd, cpu_v_block_ptr, chunk_size_in_bytes, + ssd_v_block_offset); } } else { rc = iouring.prep_write(fd, cpu_v_block_ptr, chunk_size_in_bytes, ssd_v_block_offset); if (rc < 0) { - bytes_transfer = pwrite(fd, cpu_v_block_ptr, chunk_size_in_bytes, - ssd_v_block_offset); + bytes_transfer = pwrite(fd, cpu_v_block_ptr, chunk_size_in_bytes, + ssd_v_block_offset); } } if (bytes_transfer && (bytes_transfer != chunk_size_in_bytes)) { - throw std::runtime_error("Failed to transfer K block"); + throw std::runtime_error("Failed to transfer K block"); } } // end layer loop } // end block loop @@ -181,10 +215,10 @@ static void _transfer_single_thread_impl( // NOTE that we may also use other techniques such as // AIO, O_DIRECT, and 
etc to improve the performance void transfer_kv_blocks_ssd( - SSDIOCTX &ioctx, - const torch::Tensor &cpu_layer_id_list, int64_t cpu_tensor_ptr, - const torch::Tensor &ssd_block_ids, const torch::Tensor &cpu_block_ids, - int64_t cpu_layer_stride_in_bytes, int64_t cpu_kv_stride_in_bytes, + SSDIOCTX &ioctx, const torch::Tensor &cpu_layer_id_list, + int64_t cpu_tensor_ptr, const torch::Tensor &ssd_block_ids, + const torch::Tensor &cpu_block_ids, int64_t cpu_layer_stride_in_bytes, + int64_t cpu_kv_stride_in_bytes, int64_t ssd_layer_stride_in_bytes, // in single file int64_t ssd_kv_stride_in_bytes, // in single file int64_t chunk_size_in_bytes, int64_t block_stride_in_bytes, bool is_read, @@ -212,6 +246,13 @@ void transfer_kv_blocks_ssd( cpu_block_id_ptr, ssd_block_id_ptr, num_blocks, num_devices, round_robin, cpu_blocks_partition, ssd_blocks_partition); + const bool cpu_is_block_first = + block_stride_in_bytes > cpu_layer_stride_in_bytes; + const bool ssd_is_block_first = + block_stride_in_bytes > ssd_layer_stride_in_bytes; + const bool enable_block_first_transfer = + cpu_is_block_first && ssd_is_block_first; + std::vector threads; std::vector> futures; for (int d = 0; d < num_devices; d++) { @@ -228,13 +269,12 @@ void transfer_kv_blocks_ssd( if (start_block < end_block) { if (iouring.enabled()) { _transfer_iouring_impl( - iouring, fds[d], - cpu_blocks_partition[d], ssd_blocks_partition[d], + iouring, fds[d], cpu_blocks_partition[d], ssd_blocks_partition[d], start_layer, end_layer, start_block, end_block, cpu_tensor_ptr, cpu_layer_stride_in_bytes, ssd_layer_stride_in_bytes, cpu_kv_stride_in_bytes, ssd_kv_stride_in_bytes, chunk_size_in_bytes, block_stride_in_bytes, num_files_per_device, - is_read, is_mla); + is_read, is_mla, enable_block_first_transfer); continue; } From 4472ef20df7a1b8dc7a3f0a9cf968fd06db9584c Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 3 Nov 2025 07:01:24 +0000 Subject: [PATCH 2/6] set io_uring flag=1 --- benchmarks/example_config.json | 4 
++-- docs/flexkv_config_reference/README_en.md | 4 ++-- docs/flexkv_config_reference/README_zh.md | 2 +- flexkv/common/config.py | 11 +++++------ tests/test_utils.py | 3 ++- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/benchmarks/example_config.json b/benchmarks/example_config.json index a51b5381aa..0aea8e5e3f 100644 --- a/benchmarks/example_config.json +++ b/benchmarks/example_config.json @@ -27,8 +27,8 @@ "transfer_sms_d2h": 8, "max_blocks_per_file": 32000, "ssd_cache_dir": "./ssd_cache1/", - "ssd_cache_iouring_entries": 32, - "ssd_cache_iouring_flags": 0, + "ssd_cache_iouring_entries": 512, + "ssd_cache_iouring_flags": 1, "remote_cache_size_mode": "file_size", "remote_file_size": null, "remote_file_num": null, diff --git a/docs/flexkv_config_reference/README_en.md b/docs/flexkv_config_reference/README_en.md index 8df099ba41..c20116dead 100644 --- a/docs/flexkv_config_reference/README_en.md +++ b/docs/flexkv_config_reference/README_en.md @@ -105,7 +105,7 @@ The FlexKV configuration file is a JSON file, primarily consisting of three part | `max_blocks_per_file` | int | 32000 | Maximum number of blocks per SSD file. `-1` means unlimited. | | `ssd_cache_dir` | str \| List[str] | None | **Required.** Path to SSD cache directory, e.g., `"/data/flexkv_ssd/"`. | | `ssd_cache_iouring_entries` | int | 0 | io_uring queue depth. Recommended: `512` for significantly improved concurrent I/O performance. | -| `ssd_cache_iouring_flags` | int | 0 | io_uring flags. Keep as `0` in most cases. | +| `ssd_cache_iouring_flags` | int | 1 | io_uring flags. Recommended: `1`. | > Note: To maximize bandwidth across multiple SSDs, bind each SSD to a separate directory and specify them as a list: > `"ssd_cache_dir": ["/data0/flexkv_ssd/", "/data1/flexkv_ssd/"]`. 
@@ -144,4 +144,4 @@ The FlexKV configuration file is a JSON file, primarily consisting of three part | Parameter Name | Type | Default | Description | |----------------|------|---------|-------------| -| `evict_ratio` | float | 0.0 | Ratio of blocks to proactively evict from CPU/SSD per eviction cycle. `0.0` = evict only the minimal necessary blocks (more eviction cycles may impact performance). Recommended: `0.05` (evict 5% of least recently used blocks per cycle). | \ No newline at end of file +| `evict_ratio` | float | 0.0 | Ratio of blocks to proactively evict from CPU/SSD per eviction cycle. `0.0` = evict only the minimal necessary blocks (more eviction cycles may impact performance). Recommended: `0.05` (evict 5% of least recently used blocks per cycle). | diff --git a/docs/flexkv_config_reference/README_zh.md b/docs/flexkv_config_reference/README_zh.md index 79d65ce06e..d1f7a3a279 100644 --- a/docs/flexkv_config_reference/README_zh.md +++ b/docs/flexkv_config_reference/README_zh.md @@ -105,7 +105,7 @@ FlexKV 的配置文件是一个 JSON 文件,主要包含三个部分: | `max_blocks_per_file` | int | 32000 | 单个 SSD 文件最多包含的 block 数。-1 表示无限制 | | `ssd_cache_dir` | str \| List[str] | None | SSD 缓存目录路径,**必须设置**,如 `"/data/flexkv_ssd/"` | | `ssd_cache_iouring_entries` | int | 0 | io_uring 队列深度,推荐设为 `512` 以提升并发 IO 性能,实测比不使用iouring提升较大,推荐使用512 | -| `ssd_cache_iouring_flags` | int | 0 | io_uring 标志位,一般保持 0 | +| `ssd_cache_iouring_flags` | int | 1 | io_uring 标志位,推荐设置为 1。| > 注:为了充分利用多块SSD的带宽上限,可以将多块SSD绑定至不同目录,并使用如 `"ssd cache dir": ["/data0/flexkv_ssd/", "/data1/flexkv_ssd/"]`方式初始化,SSD KVCache会均匀分布在所有SSD中,充分利用多个SSD带宽。 diff --git a/flexkv/common/config.py b/flexkv/common/config.py index aa1835bfe7..abd54af6e0 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -56,8 +56,8 @@ class CacheConfig: # ssd cache configs max_blocks_per_file: int = 32000 # -1 means no limit ssd_cache_dir: Optional[Union[str, List[str]]] = None - ssd_cache_iouring_entries: int = 0 - ssd_cache_iouring_flags: int = 
0 + ssd_cache_iouring_entries: int = 512 + ssd_cache_iouring_flags: int = 1 # gds cache configs gds_cache_dir: Optional[Union[str, List[str]]] = None @@ -81,13 +81,12 @@ class CacheConfig: evict_ratio: float = 0.0 def __post_init__(self): - layout_fields = ['gpu_kv_layout_type', - 'cpu_kv_layout_type', - 'ssd_kv_layout_type', + layout_fields = ['gpu_kv_layout_type', + 'cpu_kv_layout_type', + 'ssd_kv_layout_type', 'remote_kv_layout_type', 'gds_kv_layout_type'] for field in layout_fields: value = getattr(self, field) if isinstance(value, str): setattr(self, field, KVCacheLayoutType[value.upper()]) - diff --git a/tests/test_utils.py b/tests/test_utils.py index 86a64ef505..0297c0e5cb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -39,7 +39,8 @@ 'enable_gds': False, 'enable_trace': False, 'ssd_cache_dir': ["./ssd_cache", "./ssd_cache2/"], - 'ssd_cache_iouring_entries': 32, + 'ssd_cache_iouring_entries': 512, + 'ssd_cache_iouring_flags': 1, 'remote_cache_path': ["remote_cache1", "remote_cache2"], 'remote_config_custom': { "pcfs_fsid": "f_l91fz6", From ef86282b8639fa93cc991f388702f3cb1dc62e07 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 3 Nov 2025 08:17:01 +0000 Subject: [PATCH 3/6] batch sync for iouring --- csrc/transfer_ssd.h | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/csrc/transfer_ssd.h b/csrc/transfer_ssd.h index 9d6f9ba57e..3cf12f7ff4 100644 --- a/csrc/transfer_ssd.h +++ b/csrc/transfer_ssd.h @@ -56,21 +56,27 @@ class IOUring { } int wait_completion() { + constexpr int MAX_CQES = 32; + io_uring_cqe *cqes[MAX_CQES]; while (total_completed < total_submitted) { - if (io_uring_wait_cqe(&ring, &cqe) < 0) { - continue; + unsigned count = io_uring_peek_batch_cqe(&ring, cqes, MAX_CQES); + if (count == 0) { + if (io_uring_wait_cqe(&ring, &cqe) < 0) + continue; + count = 1; + cqes[0] = cqe; } - if (cqe->res < 0) { - fprintf(stderr, "IOUring(%p), cqe->res = %d\n", this, cqe->res); - cqe_err++; + for 
(unsigned i = 0; i < count; i++) { + if (cqes[i]->res < 0) { + cqe_err++; + } + iov2 = reinterpret_cast(io_uring_cqe_get_data(cqes[i])); + delete iov2; } - - iov2 = reinterpret_cast(io_uring_cqe_get_data(cqe)); - io_uring_cqe_seen(&ring, cqe); - total_completed++; - inflight--; - delete iov2; + total_completed += count; + inflight -= count; + io_uring_cq_advance(&ring, count); } if (cqe_err) { From 1b147d21429d115a0b1111c98c519cf7f2824574 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 3 Nov 2025 09:15:30 +0000 Subject: [PATCH 4/6] bench bidirection transfer --- benchmarks/benchmark_workers.py | 71 +++++++++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 7 deletions(-) diff --git a/benchmarks/benchmark_workers.py b/benchmarks/benchmark_workers.py index ac8e9a1832..f1694ccd52 100644 --- a/benchmarks/benchmark_workers.py +++ b/benchmarks/benchmark_workers.py @@ -27,6 +27,7 @@ class BenchmarkConfig: shuffle_ids: bool = False warmup_round: int = 1 benchmark_round: int = 10 + bidirectional: bool = False def make_configs(args: dict) -> Tuple[ModelConfig, CacheConfig, BenchmarkConfig]: config_file = args.config @@ -44,6 +45,7 @@ def make_configs(args: dict) -> Tuple[ModelConfig, CacheConfig, BenchmarkConfig] bench_config.shuffle_ids = args.shuffle_ids bench_config.warmup_round = args.warmup_round bench_config.benchmark_round = args.benchmark_round + bench_config.bidirectional = args.bi return model_config, cache_config, bench_config except Exception as e: raise ValueError(f"Failed to load config file {config_file}: {e}") from None @@ -86,7 +88,7 @@ def create_cpu_gpu_worker( # max_op_num=4, max_block_num should be larger than num_blocks_to_transfer max_block_num = max(1024, cache_config.num_cpu_blocks) op_buffer_tensor = torch.empty((4, max_block_num), dtype=torch.int64).share_memory_() - + if model_config.tp_size == 1: worker_handle = GPUCPUTransferWorker.create_worker( mp_ctx=mp.get_context('spawn'), @@ -161,7 +163,7 @@ def create_cpu_ssd_worker( # 
max_op_num=4, max_block_num should be larger than num_blocks_to_transfer max_block_num = max(1024, cache_config.num_cpu_blocks) op_buffer_tensor = torch.empty((4, max_block_num), dtype=torch.int64).share_memory_() - + worker_handle = CPUSSDDiskTransferWorker.create_worker( mp_ctx=mp.get_context('spawn'), finished_ops_queue=finished_ops_queue, @@ -182,11 +184,18 @@ def create_cpu_ssd_worker( def launch_transfer(worker_handle: WorkerHandle, finished_ops_queue: mp.Queue, transfer_op: TransferOp): - op_id = transfer_op.op_id worker_handle.submit_transfer(transfer_op) - ret_op_id = finished_ops_queue.get() - assert ret_op_id == op_id - return True + +def sync_all(finished_ops_queue: mp.Queue, num_ops: int): + for _ in range(num_ops): + finished_ops_queue.get() + +REVERSE_TYPE_MAP = { + TransferType.D2H: TransferType.H2D, + TransferType.H2D: TransferType.D2H, + TransferType.DISK2H: TransferType.H2DISK, + TransferType.H2DISK: TransferType.DISK2H, + } def bench_worker(args): model_config, cache_config, bench_config = make_configs(args) @@ -204,6 +213,7 @@ def bench_worker(args): num_layers_to_transfer = model_config.num_layers num_blocks_to_transfer = bench_config.num_blocks_to_transfer shuffle_ids = bench_config.shuffle_ids + bidirectional = bench_config.bidirectional if transfer_type == TransferType.H2D or transfer_type == TransferType.D2H: worker_handle, finished_ops_queue = create_cpu_gpu_worker(model_config, cache_config) @@ -213,6 +223,15 @@ def bench_worker(args): raise ValueError(f"Unsupported transfer type: {transfer_type} for benchmark, " f"currently only support {TransferType.H2D.name}, {TransferType.D2H.name}, " f"{TransferType.H2DISK.name}, {TransferType.DISK2H.name}") + reverse_worker_handle = None + reverse_finished_ops_queue = None + if bidirectional: + if transfer_type == TransferType.H2D or transfer_type == TransferType.D2H: + reverse_worker_handle, reverse_finished_ops_queue = \ + create_cpu_gpu_worker(model_config, cache_config) + elif transfer_type == 
TransferType.H2DISK or transfer_type == TransferType.DISK2H: + reverse_worker_handle, reverse_finished_ops_queue = \ + create_cpu_ssd_worker(model_config, cache_config) if shuffle_ids: block_ids = torch.randperm(num_blocks_to_transfer).numpy() @@ -230,21 +249,54 @@ def bench_worker(args): successors=[], predecessors=[], ) - if transfer_type == TransferType.DISK2H: + + reverse_transfer_op = None + if bidirectional: + reverse_type = REVERSE_TYPE_MAP.get(transfer_type) + if reverse_type is None: + raise ValueError(f"Bidirectional test not supported for transfer type: {transfer_type}") + + reverse_block_ids = torch.randperm(num_blocks_to_transfer).numpy() + + reverse_transfer_op = TransferOp( + transfer_type=reverse_type, + layer_id=0, + layer_granularity=num_layers_to_transfer, + src_block_ids=reverse_block_ids, + dst_block_ids=reverse_block_ids, + graph_id=1, + dp_id=0, + successors=[], + predecessors=[], + ) + if transfer_type == TransferType.DISK2H or transfer_type == TransferType.H2DISK: tmp_op = copy.deepcopy(transfer_op) tmp_op.transfer_type = TransferType.H2DISK tmp_op.src_block_ids = transfer_op.dst_block_ids tmp_op.dst_block_ids = transfer_op.src_block_ids launch_transfer(worker_handle, finished_ops_queue, tmp_op) + sync_all(finished_ops_queue, 1) + for _ in range(warmup_round): + if bidirectional: + launch_transfer(reverse_worker_handle, reverse_finished_ops_queue, reverse_transfer_op) launch_transfer(worker_handle, finished_ops_queue, transfer_op) + sync_all(finished_ops_queue, warmup_round) + if bidirectional: + sync_all(reverse_finished_ops_queue, warmup_round) + pbar = tqdm(total=benchmark_round, desc="Benchmarking") start_time = time.time() for _ in range(benchmark_round): + if bidirectional: + launch_transfer(reverse_worker_handle, reverse_finished_ops_queue, reverse_transfer_op) launch_transfer(worker_handle, finished_ops_queue, transfer_op) pbar.update(1) pbar.close() + sync_all(finished_ops_queue, benchmark_round) end_time = time.time() + if 
bidirectional: + sync_all(reverse_finished_ops_queue, benchmark_round) total_data_size_GB = ( num_blocks_to_transfer * cache_config.tokens_per_block * @@ -257,6 +309,8 @@ def bench_worker(args): print(f"Avg Time taken: {avg_time} seconds") print(f"Avg Bandwidth: {total_data_size_GB / avg_time} GB/s") worker_handle.shutdown() + if bidirectional: + reverse_worker_handle.shutdown() def parse_args(): parser = ArgumentParser() @@ -280,6 +334,9 @@ def parse_args(): parser.add_argument("--benchmark-round", type=int, default=10) + parser.add_argument("--bi", + action="store_true", + help="benchmark bidirectional bandwidth") return parser.parse_args() if __name__ == "__main__": From c799c930ca339b183edea810e3a3fd1a97d0b109 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 3 Nov 2025 10:36:42 +0000 Subject: [PATCH 5/6] swap loop to improve multi-SSD bandwidth --- csrc/transfer_ssd.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/transfer_ssd.cpp b/csrc/transfer_ssd.cpp index 4aa72c6888..4cafc266dc 100644 --- a/csrc/transfer_ssd.cpp +++ b/csrc/transfer_ssd.cpp @@ -255,8 +255,8 @@ void transfer_kv_blocks_ssd( std::vector threads; std::vector> futures; - for (int d = 0; d < num_devices; d++) { - for (int t = 0; t < num_threads_per_device; t++) { + for (int t = 0; t < num_threads_per_device; t++) { + for (int d = 0; d < num_devices; d++) { int start_layer = cpu_layer_id_list_ptr[0]; int end_layer = cpu_layer_id_list_ptr[0] + num_layers; int num_transfer_blocks = cpu_blocks_partition[d].size(); @@ -302,8 +302,8 @@ void transfer_kv_blocks_ssd( } }); } - } // end thread loop - } // end device loop + } // end device loop + } // end thread loop if (iouring.enabled()) { if (iouring.wait_completion()) { From 99bbd9f3f9c3978c37cea884d8ff451b7d70504f Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 3 Nov 2025 10:38:20 +0000 Subject: [PATCH 6/6] prefer read --- csrc/transfer_ssd.h | 16 +++++++++ flexkv/transfer/worker.py | 73 
+++++++++++++++++---------------------- 2 files changed, 48 insertions(+), 41 deletions(-) diff --git a/csrc/transfer_ssd.h b/csrc/transfer_ssd.h index 3cf12f7ff4..04d2cbeb17 100644 --- a/csrc/transfer_ssd.h +++ b/csrc/transfer_ssd.h @@ -1,9 +1,19 @@ #pragma once #include #include +#include #include #include +#ifndef IOPRIO_CLASS_SHIFT +#define IOPRIO_CLASS_SHIFT 13 +#define IOPRIO_PRIO_MASK ((1UL << IOPRIO_CLASS_SHIFT) - 1) +#define IOPRIO_PRIO_VALUE(class, data) (((class) << IOPRIO_CLASS_SHIFT) | data) +#define IOPRIO_CLASS_RT 1 +#define IOPRIO_CLASS_BE 2 +#define IOPRIO_CLASS_IDLE 3 +#endif + namespace flexkv { class IOUring { @@ -102,6 +112,9 @@ class IOUring { iov->iov_base = ptr; iov->iov_len = size; io_uring_prep_readv(sqe, fd, iov, 1, offset); + + sqe->ioprio = IOPRIO_PRIO_VALUE(IOPRIO_CLASS_RT, 0); + io_uring_sqe_set_data(sqe, iov); prepared++; return 0; @@ -122,6 +135,9 @@ class IOUring { iov->iov_base = ptr; iov->iov_len = size; io_uring_prep_writev(sqe, fd, iov, 1, offset); + + sqe->ioprio = IOPRIO_PRIO_VALUE(IOPRIO_CLASS_BE, 4); + io_uring_sqe_set_data(sqe, iov); prepared++; return 0; diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index f192e70fce..254e0b4e28 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -16,7 +16,8 @@ from flexkv import c_ext -from flexkv.c_ext import transfer_kv_blocks, transfer_kv_blocks_ssd, transfer_kv_blocks_gds, TPTransferThreadGroup, TPGDSTransferThreadGroup +from flexkv.c_ext import transfer_kv_blocks, transfer_kv_blocks_ssd, \ + transfer_kv_blocks_gds, TPTransferThreadGroup, TPGDSTransferThreadGroup from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType @@ -322,18 +323,6 @@ def __init__(self, self.use_ce_transfer_h2d = use_ce_transfer_h2d self.use_ce_transfer_d2h = use_ce_transfer_d2h - print(f"GPU block type: {self.gpu_block_type_}") - print(f"GPU blocks 
pointers: {self.gpu_blocks_ptrs}") - print(f"GPU tensor pointers: {self.gpu_tensor_ptrs}") - print(f"chunk size: {self.chunk_size_in_bytes}") - print(f"gpu kv stride: {self.gpu_kv_stride_in_bytes}") - print(f"gpu block stride: {self.gpu_block_stride_in_bytes}") - print(f"gpu layer stride: {self.gpu_layer_stride_in_bytes}") - print(f"cpu layer stride: {self.cpu_layer_stride_in_bytes}") - print(f"cpu kv stride: {self.cpu_kv_stride_in_bytes}") - print(f"cpu block stride: {self.cpu_block_stride_in_bytes}") - print(f"num layers: {self.num_layers}") - def _transfer_impl( self, src_block_ids: torch.Tensor, @@ -457,7 +446,7 @@ def __init__(self, cudaHostRegister(cpu_blocks) self.num_layers = gpu_kv_layouts[0].num_layer - + # here the chunk size doesn't include the layer info self.gpu_chunk_sizes_in_bytes = [gpu_kv_layout.get_chunk_size() * self.dtype.itemsize \ for gpu_kv_layout in gpu_kv_layouts] @@ -483,8 +472,9 @@ def __init__(self, gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) gpu_layer_strides_tensor = torch.tensor(self.gpu_layer_strides_in_bytes, dtype=torch.int64) self.tp_transfer_thread_group = TPTransferThreadGroup(self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id, - self.num_layers, gpu_kv_strides_tensor, - gpu_block_strides_tensor, gpu_layer_strides_tensor, gpu_chunk_sizes_tensor) + self.num_layers, gpu_kv_strides_tensor, + gpu_block_strides_tensor, gpu_layer_strides_tensor, + gpu_chunk_sizes_tensor) def _transfer_impl(self, @@ -871,7 +861,7 @@ def __init__( ) -> None: """ Initialize GDS Transfer Worker - + Args: worker_id: Worker ID transfer_queue: Queue for incoming transfer operations @@ -880,42 +870,43 @@ def __init__( gds_file_paths: List of GDS file paths (will create GDSManager from these) num_blocks_per_file: Number of blocks per file gpu_kv_layout: Layout of GPU KV cache - gds_kv_layout: Layout of GDS KV cache + gds_kv_layout: Layout of GDS KV cache dtype: Data type gpu_device_id: GPU device ID """ # 
Initialize base class first super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) - + self.gpu_blocks = [wrapper.get_tensor() for wrapper in gpu_blocks] self.gpu_blocks_ptrs = self._get_layer_ptrs(self.gpu_blocks) self.gpu_layer_ptrs = self.gpu_blocks_ptrs self.num_blocks_per_file = num_blocks_per_file - + # Create GDSManager from file paths in this worker process from flexkv import c_ext self.gds_manager = c_ext.GDSManager(gds_file_paths) if not self.gds_manager.is_ready(): - raise RuntimeError(f"Failed to initialize GDS Manager in worker {worker_id}: {self.gds_manager.get_last_error()}") - + raise RuntimeError(f"Failed to initialize GDS Manager in worker {worker_id}: " + f"{self.gds_manager.get_last_error()}") + self.dtype = dtype self.is_mla = gpu_kv_layout.is_mla - + # Layout information self.num_layers = gpu_kv_layout.num_layer gpu_kv_layout_per_layer = gpu_kv_layout.div_layer(self.num_layers) gds_kv_layout_per_layer = gds_kv_layout.div_layer(self.num_layers) - + # GPU layout calculations self.chunk_size_in_bytes = gpu_kv_layout_per_layer.get_chunk_size() * self.dtype.itemsize self.gpu_kv_stride_in_bytes = gpu_kv_layout_per_layer.get_kv_stride() * self.dtype.itemsize self.gpu_block_stride_in_bytes = gpu_kv_layout_per_layer.get_block_stride() * self.dtype.itemsize - - # GDS layout calculations + + # GDS layout calculations self.gds_layer_stride_in_bytes = gds_kv_layout.get_layer_stride() * self.dtype.itemsize self.gds_kv_stride_in_bytes = gds_kv_layout.get_kv_stride() * self.dtype.itemsize self.gds_block_stride_in_bytes = gds_kv_layout.get_block_stride() * self.dtype.itemsize - + # Set GPU device and create stream self.gpu_device_id = gpu_device_id if gpu_device_id != -1: @@ -958,15 +949,15 @@ def _transfer_impl( # Process transfer for each layer layer_id_list = torch.arange(layer_id, layer_id + layer_granularity, dtype=torch.int32) - + # Get managed GDS files gds_files = self.gds_manager.get_managed_files() if not gds_files: raise 
RuntimeError("No GDS files available") - + # Determine if this is a read operation is_read = (transfer_type == TransferType.GDS2D) - + # Use the optimized C++ function for KV block transfers with multi-file support try: transfer_kv_blocks_gds( @@ -978,7 +969,7 @@ def _transfer_impl( gpu_block_id_list, # GPU block IDs self.gpu_kv_stride_in_bytes, # GPU K-V stride self.gds_layer_stride_in_bytes, # GDS layer stride - self.gds_block_stride_in_bytes, # GDS block stride + self.gds_block_stride_in_bytes, # GDS block stride self.gds_kv_stride_in_bytes, # GDS K-V stride self.chunk_size_in_bytes, # Block size self.num_blocks_per_file, # Blocks per file @@ -987,10 +978,10 @@ def _transfer_impl( False, # Verbose logging self.is_mla # MLA ) - + except Exception as e: flexkv_logger.error(f"GDS transfer failed: {e}") - raise RuntimeError(f"Failed to transfer KV blocks: {e}") + raise RuntimeError(f"Failed to transfer KV blocks: {e}") from e def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: """Launch a GDS transfer operation""" @@ -1043,7 +1034,7 @@ def __init__( ) -> None: """ Initialize TP GDS Transfer Worker - + Args: worker_id: Worker ID transfer_queue: Queue for incoming transfer operations @@ -1059,7 +1050,7 @@ def __init__( """ # Initialize base class first super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) - + assert len(gpu_blocks) == tp_group_size # Handle tensor import for multi-process case imported_gpu_blocks = [] @@ -1070,34 +1061,34 @@ def __init__( imported_gpu_blocks.append(blocks_in_one_gpu) self.gpu_blocks = imported_gpu_blocks self.num_blocks_per_file = num_blocks_per_file - + self.dtype = dtype self.is_mla = gpu_kv_layout.is_mla self.num_gpus = len(self.gpu_blocks) self.tp_group_size = tp_group_size self.dp_group_id = dp_group_id - + # Layout information self.num_layers = gpu_kv_layout.num_layer gpu_kv_layout_per_layer = gpu_kv_layout.div_layer(self.num_layers) gds_kv_layout_per_layer = 
gds_kv_layout.div_layer(self.num_layers) - + # GPU layout calculations self.gpu_chunk_size_in_bytes = gpu_kv_layout_per_layer.get_chunk_size() * self.dtype.itemsize self.gpu_kv_stride_in_bytes = gpu_kv_layout_per_layer.get_kv_stride() * self.dtype.itemsize self.gpu_block_stride_in_bytes = gpu_kv_layout_per_layer.get_block_stride() * self.dtype.itemsize - + # GDS layout calculations self.gds_chunk_size_in_bytes = gds_kv_layout_per_layer.get_chunk_size() * self.dtype.itemsize self.gds_layer_stride_in_bytes = gds_kv_layout.get_layer_stride() * self.dtype.itemsize self.gds_kv_stride_in_bytes = gds_kv_layout.get_kv_stride() * self.dtype.itemsize self.gds_block_stride_in_bytes = gds_kv_layout.get_block_stride() * self.dtype.itemsize - + if not gpu_kv_layout.type == KVCacheLayoutType.LAYERWISE: raise ValueError("Only layerwise layout is supported for GPU") if not gds_kv_layout.type == KVCacheLayoutType.LAYERWISE: raise ValueError("Only layerwise layout is supported for GDS") - + # Create TP GDS Transfer Thread Group self.tp_gds_transfer_thread_group = TPGDSTransferThreadGroup( self.num_gpus, self.gpu_blocks, gds_file_paths, dp_group_id)