csrc/bindings.cpp: 100 changes (56 additions, 44 deletions)
@@ -1,8 +1,8 @@
 #include <cstddef>
 #include <cstdint>
+#include <map>
 #include <stdexcept>
 #include <vector>
-#include <map>

 #include <ATen/cuda/CUDAContext.h>
 #include <cuda_runtime.h>
@@ -16,36 +16,38 @@

 #include "cache_utils.h"
 #include "pcfs/pcfs.h"
+#include "radix_tree.h"
 #include "tp_transfer_thread_group.h"
 #include "transfer.cuh"
 #include "transfer_ssd.h"
-#include "radix_tree.h"

 namespace py = pybind11;

 void transfer_kv_blocks_binding(
-    torch::Tensor &gpu_block_id_tensor, torch::Tensor &gpu_layer_ptrs_tensor,
-    int64_t gpu_kv_stride_in_bytes, int64_t gpu_block_stride_in_bytes,
+    torch::Tensor &gpu_block_id_tensor, torch::Tensor &k_gpu_layer_ptrs_tensor,
+    torch::Tensor &v_gpu_layer_ptrs_tensor, int64_t gpu_block_stride_in_bytes,
     torch::Tensor &cpu_block_id_tensor, torch::Tensor &cpu_tensor,
     int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes,
     int64_t cpu_block_stride_in_bytes, int64_t chunk_size_in_bytes,
     int start_layer_id, int transfer_sms = -1, bool is_host_to_device = true,
     bool use_ce_transfer = false, bool is_mla = false) {
   int num_blocks = gpu_block_id_tensor.numel();
-  int num_layers = gpu_layer_ptrs_tensor.numel();
+  int num_layers = k_gpu_layer_ptrs_tensor.numel();

   int64_t *gpu_block_ids =
       static_cast<int64_t *>(gpu_block_id_tensor.data_ptr());
-  void **gpu_layer_ptrs = static_cast<void **>(
-      gpu_layer_ptrs_tensor.data_ptr()); // must be contiguous
+  void **k_gpu_layer_ptrs = static_cast<void **>(
+      k_gpu_layer_ptrs_tensor.data_ptr()); // must be contiguous
+  void **v_gpu_layer_ptrs = static_cast<void **>(
+      v_gpu_layer_ptrs_tensor.data_ptr()); // must be contiguous
   int64_t *cpu_block_ids =
       static_cast<int64_t *>(cpu_block_id_tensor.data_ptr());
   void *cpu_ptr = static_cast<void *>(cpu_tensor.data_ptr());

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   flexkv::transfer_kv_blocks(
-      num_blocks, start_layer_id, num_layers, gpu_block_ids, gpu_layer_ptrs,
-      gpu_kv_stride_in_bytes, gpu_block_stride_in_bytes, 0, cpu_block_ids, cpu_ptr,
+      num_blocks, start_layer_id, num_layers, gpu_block_ids, k_gpu_layer_ptrs,
+      v_gpu_layer_ptrs, gpu_block_stride_in_bytes, 0, cpu_block_ids, cpu_ptr,
       cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes,
       cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms,
       is_host_to_device, use_ce_transfer, is_mla);
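
A note on the new split K/V pointer arguments: the binding casts each tensor's data_ptr() directly to void **, so k_gpu_layer_ptrs_tensor and v_gpu_layer_ptrs_tensor must be contiguous int64 tensors holding one raw device pointer per layer. A minimal caller-side sketch of building such tables; the helper name and the per-layer [2, num_blocks, num_kv_heads, head_size] cache layout are assumptions for illustration, not APIs from this PR:

#include <torch/torch.h>
#include <utility>
#include <vector>

// Hypothetical helper: pack per-layer K and V device pointers into two
// contiguous int64 host tensors, one slot per layer.
std::pair<torch::Tensor, torch::Tensor>
make_layer_ptr_tensors(const std::vector<torch::Tensor> &kv_caches) {
  const int64_t num_layers = static_cast<int64_t>(kv_caches.size());
  torch::Tensor k_ptrs = torch::empty({num_layers}, torch::kInt64);
  torch::Tensor v_ptrs = torch::empty({num_layers}, torch::kInt64);
  int64_t *k_out = k_ptrs.data_ptr<int64_t>();
  int64_t *v_out = v_ptrs.data_ptr<int64_t>();
  for (int64_t l = 0; l < num_layers; ++l) {
    // Assumed layout per layer: [2, num_blocks, num_kv_heads, head_size],
    // with index 0 holding K and index 1 holding V.
    k_out[l] = reinterpret_cast<int64_t>(kv_caches[l][0].data_ptr());
    v_out[l] = reinterpret_cast<int64_t>(kv_caches[l][1].data_ptr());
  }
  return {k_ptrs, v_ptrs};
}

Passing K and V as two independent tables is also what lets the MLA path hand over a valid K table while leaving the V entries null, as the constructor change in tp_transfer_thread_group.cpp further down does.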
@@ -56,22 +58,21 @@ void transfer_kv_blocks_binding(
 }

 void transfer_kv_blocks_ssd_binding(
-    flexkv::SSDIOCTX &ioctx,
-    const torch::Tensor &cpu_layer_id_list, int64_t cpu_tensor_ptr,
-    const torch::Tensor &ssd_block_ids, const torch::Tensor &cpu_block_ids,
-    int64_t cpu_layer_stride_in_bytes, int64_t cpu_kv_stride_in_bytes,
-    int64_t ssd_layer_stride_in_bytes, int64_t ssd_kv_stride_in_bytes,
-    int64_t chunk_size_in_bytes, int64_t block_stride_in_bytes, bool is_read,
-    int num_blocks_per_file, int round_robin = 1,
-    int num_threads_per_device = 8, bool is_mla = false) {
+    flexkv::SSDIOCTX &ioctx, const torch::Tensor &cpu_layer_id_list,
+    int64_t cpu_tensor_ptr, const torch::Tensor &ssd_block_ids,
+    const torch::Tensor &cpu_block_ids, int64_t cpu_layer_stride_in_bytes,
+    int64_t cpu_kv_stride_in_bytes, int64_t ssd_layer_stride_in_bytes,
+    int64_t ssd_kv_stride_in_bytes, int64_t chunk_size_in_bytes,
+    int64_t block_stride_in_bytes, bool is_read, int num_blocks_per_file,
+    int round_robin = 1, int num_threads_per_device = 8, bool is_mla = false) {
   TORCH_CHECK(ssd_block_ids.dtype() == torch::kInt64,
               "ssd_block_ids must be int64");
   TORCH_CHECK(cpu_block_ids.dtype() == torch::kInt64,
               "cpu_block_ids must be int64");

   flexkv::transfer_kv_blocks_ssd(
-      ioctx, cpu_layer_id_list, cpu_tensor_ptr, ssd_block_ids,
-      cpu_block_ids, cpu_layer_stride_in_bytes, cpu_kv_stride_in_bytes,
+      ioctx, cpu_layer_id_list, cpu_tensor_ptr, ssd_block_ids, cpu_block_ids,
+      cpu_layer_stride_in_bytes, cpu_kv_stride_in_bytes,
       ssd_layer_stride_in_bytes, ssd_kv_stride_in_bytes, chunk_size_in_bytes,
       block_stride_in_bytes, is_read, num_blocks_per_file, round_robin,
       num_threads_per_device, is_mla);
@@ -109,15 +110,15 @@ PYBIND11_MODULE(c_ext, m) {
   m.def("transfer_kv_blocks", &transfer_kv_blocks_binding,
         "Transfer multi-layer KV-cache between CPU and GPU");
   m.def("transfer_kv_blocks_ssd", &transfer_kv_blocks_ssd_binding,
-        "Transfer KV blocks between SSD and CPU memory",
-        py::arg("ioctx"), py::arg("cpu_layer_id_list"),
-        py::arg("cpu_tensor_ptr"), py::arg("ssd_block_ids"),
-        py::arg("cpu_block_ids"), py::arg("cpu_layer_stride_in_bytes"),
-        py::arg("cpu_kv_stride_in_bytes"), py::arg("ssd_layer_stride_in_bytes"),
-        py::arg("ssd_kv_stride_in_bytes"), py::arg("chunk_size_in_bytes"),
-        py::arg("block_stride_in_bytes"), py::arg("is_read"),
-        py::arg("num_blocks_per_file"), py::arg("round_robin") = 1,
-        py::arg("num_threads_per_device") = 16, py::arg("is_mla") = false);
+        "Transfer KV blocks between SSD and CPU memory", py::arg("ioctx"),
+        py::arg("cpu_layer_id_list"), py::arg("cpu_tensor_ptr"),
+        py::arg("ssd_block_ids"), py::arg("cpu_block_ids"),
+        py::arg("cpu_layer_stride_in_bytes"), py::arg("cpu_kv_stride_in_bytes"),
+        py::arg("ssd_layer_stride_in_bytes"), py::arg("ssd_kv_stride_in_bytes"),
+        py::arg("chunk_size_in_bytes"), py::arg("block_stride_in_bytes"),
+        py::arg("is_read"), py::arg("num_blocks_per_file"),
+        py::arg("round_robin") = 1, py::arg("num_threads_per_device") = 16,
+        py::arg("is_mla") = false);
 #ifdef FLEXKV_ENABLE_CFS
   m.def("transfer_kv_blocks_remote", &transfer_kv_blocks_remote,
         "Transfer KV blocks between remote and CPU memory",
@@ -140,11 +141,12 @@ PYBIND11_MODULE(c_ext, m) {
         py::arg("block_hashes"));

   py::class_<flexkv::SSDIOCTX>(m, "SSDIOCTX")
-      .def(py::init<std::map<int, std::vector<std::string>> &, int, int, int>());
+      .def(
+          py::init<std::map<int, std::vector<std::string>> &, int, int, int>());
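
For context on the SSDIOCTX constructor bound above: its first argument maps a device index to the SSD cache files that device owns; the semantics of the three trailing ints are not visible in this diff. A purely illustrative construction, with made-up paths and placeholder values:

#include <map>
#include <string>
#include <vector>

// Illustrative map argument only; file paths are hypothetical.
std::map<int, std::vector<std::string>> ssd_files = {
    {0, {"/mnt/nvme0/flexkv_cache_0.bin", "/mnt/nvme0/flexkv_cache_1.bin"}},
    {1, {"/mnt/nvme1/flexkv_cache_0.bin", "/mnt/nvme1/flexkv_cache_1.bin"}},
};
// flexkv::SSDIOCTX ioctx(ssd_files, 0, 0, 0); // the three ints' meaning is
//                                             // not shown in this diff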

   py::class_<flexkv::TPTransferThreadGroup>(m, "TPTransferThreadGroup")
       .def(py::init<int, const std::vector<std::vector<torch::Tensor>> &,
-                    torch::Tensor &, int, torch::Tensor &, torch::Tensor &, torch::Tensor &>())
+                    torch::Tensor &, int, torch::Tensor &, torch::Tensor &>())
       .def("tp_group_transfer",
            &flexkv::TPTransferThreadGroup::tp_group_transfer,
            py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"),
@@ -201,29 +203,39 @@ PYBIND11_MODULE(c_ext, m) {
       .def("reset", &flexkv::CRadixTreeIndex::reset)
       .def("lock", &flexkv::CRadixTreeIndex::lock, py::arg("node"))
       .def("unlock", &flexkv::CRadixTreeIndex::unlock, py::arg("node"))
-      .def("set_ready", &flexkv::CRadixTreeIndex::set_ready,
-           py::arg("node"), py::arg("ready"), py::arg("ready_length"))
-      .def("insert", &flexkv::CRadixTreeIndex::insert, py::return_value_policy::reference,
-           py::arg("physical_block_ids"), py::arg("block_hashes"), py::arg("num_blocks"),
-           py::arg("num_insert_blocks"), py::arg("ready") = true, py::arg("node") = nullptr,
-           py::arg("num_matched_blocks") = -1, py::arg("last_node_matched_length") = -1)
-      .def("evict", &flexkv::CRadixTreeIndex::evict, py::arg("evicted_blocks"), py::arg("num_evicted"))
+      .def("set_ready", &flexkv::CRadixTreeIndex::set_ready, py::arg("node"),
+           py::arg("ready"), py::arg("ready_length"))
+      .def("insert", &flexkv::CRadixTreeIndex::insert,
+           py::return_value_policy::reference, py::arg("physical_block_ids"),
+           py::arg("block_hashes"), py::arg("num_blocks"),
+           py::arg("num_insert_blocks"), py::arg("ready") = true,
+           py::arg("node") = nullptr, py::arg("num_matched_blocks") = -1,
+           py::arg("last_node_matched_length") = -1)
+      .def("evict", &flexkv::CRadixTreeIndex::evict, py::arg("evicted_blocks"),
+           py::arg("num_evicted"))
       .def("total_cached_blocks", &flexkv::CRadixTreeIndex::total_cached_blocks)
-      .def("total_unready_blocks", &flexkv::CRadixTreeIndex::total_unready_blocks)
+      .def("total_unready_blocks",
+           &flexkv::CRadixTreeIndex::total_unready_blocks)
       .def("total_ready_blocks", &flexkv::CRadixTreeIndex::total_ready_blocks)
       .def("match_prefix", &flexkv::CRadixTreeIndex::match_prefix,
-           py::arg("block_hashes"), py::arg("num_blocks"), py::arg("update_cache_info"));
+           py::arg("block_hashes"), py::arg("num_blocks"),
+           py::arg("update_cache_info"));

   py::class_<flexkv::CRadixNode>(m, "CRadixNode")
       .def(py::init<flexkv::CRadixTreeIndex *, bool, int>())
       .def("size", &flexkv::CRadixNode::size);

-  py::class_<flexkv::CMatchResult, std::shared_ptr<flexkv::CMatchResult>>(m, "CMatchResult")
-      .def(py::init<int, int, int, flexkv::CRadixNode *, flexkv::CRadixNode *, std::vector<int64_t> *>())
+  py::class_<flexkv::CMatchResult, std::shared_ptr<flexkv::CMatchResult>>(
+      m, "CMatchResult")
+      .def(py::init<int, int, int, flexkv::CRadixNode *, flexkv::CRadixNode *,
+                    std::vector<int64_t> *>())
       .def_readonly("last_ready_node", &flexkv::CMatchResult::last_ready_node)
       .def_readonly("last_node", &flexkv::CMatchResult::last_node)
       .def_readonly("physical_blocks", &flexkv::CMatchResult::physical_blocks)
-      .def_readonly("num_ready_matched_blocks", &flexkv::CMatchResult::num_ready_matched_blocks)
-      .def_readonly("num_matched_blocks", &flexkv::CMatchResult::num_matched_blocks)
-      .def_readonly("last_node_matched_length", &flexkv::CMatchResult::last_node_matched_length);
+      .def_readonly("num_ready_matched_blocks",
+                    &flexkv::CMatchResult::num_ready_matched_blocks)
+      .def_readonly("num_matched_blocks",
+                    &flexkv::CMatchResult::num_matched_blocks)
+      .def_readonly("last_node_matched_length",
+                    &flexkv::CMatchResult::last_node_matched_length);
 }
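
Two binding details above are worth a gloss: CMatchResult uses std::shared_ptr as its pybind11 holder type, so Python and C++ can share ownership of match results, while insert returns its node with return_value_policy::reference, meaning Python merely borrows a pointer whose lifetime the C++ index controls. A minimal standalone sketch of both patterns, using toy types rather than FlexKV's real classes:

#include <memory>
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Result { int value = 0; }; // returned via shared_ptr: shared ownership
struct Node { int size = 0; };    // owned by C++: Python only borrows it

struct Index {
  Node root;
  Node *find() { return &root; } // lifetime managed by the Index
  std::shared_ptr<Result> match() { return std::make_shared<Result>(); }
};

PYBIND11_MODULE(holder_demo, m) {
  py::class_<Result, std::shared_ptr<Result>>(m, "Result")
      .def_readonly("value", &Result::value);
  py::class_<Node>(m, "Node").def_readonly("size", &Node::size);
  py::class_<Index>(m, "Index")
      .def(py::init<>())
      // reference: no ownership transfer; the Index must outlive the Node
      .def("find", &Index::find, py::return_value_policy::reference)
      .def("match", &Index::match); // shared_ptr holder keeps Result alive
}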
csrc/tp_transfer_thread_group.cpp: 112 changes (64 additions, 48 deletions)
@@ -1,6 +1,6 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
+ * SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved. SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,36 +23,44 @@ namespace flexkv {
 TPTransferThreadGroup::TPTransferThreadGroup(
     int num_gpus, const std::vector<std::vector<torch::Tensor>> &gpu_blocks,
     torch::Tensor &cpu_blocks, int dp_group_id,
-    torch::Tensor &gpu_kv_strides_tensor,
     torch::Tensor &gpu_block_strides_tensor,
     torch::Tensor &gpu_chunk_sizes_tensor) {

   num_gpus_ = num_gpus;

-  gpu_kv_strides_in_bytes_ = new int64_t[num_gpus];
-
   gpu_block_strides_in_bytes_ = new int64_t[num_gpus];
   gpu_chunk_sizes_in_bytes_ = new int64_t[num_gpus];

-  int64_t* kv_strides_ptr = gpu_kv_strides_tensor.data_ptr<int64_t>();
-  int64_t* block_strides_ptr = gpu_block_strides_tensor.data_ptr<int64_t>();
-  int64_t* chunk_sizes_ptr = gpu_chunk_sizes_tensor.data_ptr<int64_t>();
-
+  int64_t *block_strides_ptr = gpu_block_strides_tensor.data_ptr<int64_t>();
+  int64_t *chunk_sizes_ptr = gpu_chunk_sizes_tensor.data_ptr<int64_t>();

   for (int i = 0; i < num_gpus; i++) {
-    gpu_kv_strides_in_bytes_[i] = kv_strides_ptr[i];
     gpu_block_strides_in_bytes_[i] = block_strides_ptr[i];
     gpu_chunk_sizes_in_bytes_[i] = chunk_sizes_ptr[i];
   }

   queues_.resize(num_gpus_);
   mtxs_ = std::vector<std::mutex>(num_gpus_);
   cvs_ = std::vector<std::condition_variable>(num_gpus_);

   int num_layers = gpu_blocks[0].size();
-  cudaMallocHost((void **)&gpu_blocks_,
+  cudaMallocHost((void **)&k_gpu_blocks_,
                  num_gpus_ * num_layers * sizeof(void *));
+  cudaMallocHost((void **)&v_gpu_blocks_,
+                 num_gpus_ * num_layers * sizeof(void *));

   for (int i = 0; i < num_gpus_; ++i) {
     for (int j = 0; j < num_layers; ++j) {
-      gpu_blocks_[i * num_layers + j] = gpu_blocks[i][j].data_ptr();
+      // gpu_blocks[gpu][layer]: [kv_dim, num_blocks, num_kv_heads, head_size]
+      // kv_dim=2: gpu_blocks = k_caches + v_caches
+      torch::Tensor kv_tensor = gpu_blocks[i][j];
+      if (kv_tensor.size(0) == 2) { // Non-MLA case: [2, num_blocks, ...]
+        k_gpu_blocks_[i * num_layers + j] = kv_tensor[0].data_ptr();
+        v_gpu_blocks_[i * num_layers + j] = kv_tensor[1].data_ptr();
+      } else { // MLA case: [1, num_blocks, ...]
+        k_gpu_blocks_[i * num_layers + j] = kv_tensor.data_ptr();
+        v_gpu_blocks_[i * num_layers + j] = nullptr; // V not used in MLA
+      }
     }
   }
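
The two pinned host arrays allocated above are flat [num_gpus * num_layers] pointer tables in GPU-major order; tp_group_transfer later slices them with plain pointer arithmetic (k_gpu_blocks_ + i * num_layers + layer_id). A toy standalone illustration of that layout and slicing, with ordinary malloc standing in for cudaMallocHost:

#include <cassert>
#include <cstdlib>

int main() {
  const int num_gpus = 2, num_layers = 3;
  // Flat GPU-major table: row i holds the per-layer pointers of GPU i.
  void **k_table = static_cast<void **>(
      std::malloc(num_gpus * num_layers * sizeof(void *)));
  int dummy[num_gpus][num_layers];
  for (int i = 0; i < num_gpus; ++i)
    for (int j = 0; j < num_layers; ++j)
      k_table[i * num_layers + j] = &dummy[i][j];
  // Slicing out "GPU 1, starting at layer 2" is pointer arithmetic only:
  void **gpu1_from_layer2 = k_table + 1 * num_layers + 2;
  assert(*gpu1_from_layer2 == &dummy[1][2]);
  std::free(k_table);
  return 0;
}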

@@ -65,47 +73,51 @@ TPTransferThreadGroup::TPTransferThreadGroup(
     cudaStreamCreate(&streams_[i]);
   }
   // create the thread pool
-  stop_pool_=false;
+  stop_pool_ = false;
   for (int i = 0; i < num_gpus_; ++i) {
     threads_.emplace_back([this, i]() {
       int device_id = dp_group_id_ * num_gpus_ + i;
       cudaSetDevice(device_id); // only once

       while (true) {
         Task task;
         {
           std::unique_lock<std::mutex> lk(mtxs_[i]);
-          cvs_[i].wait(lk, [&]{ return stop_pool_ || !queues_[i].empty(); });
-          if (stop_pool_ && queues_[i].empty()) return;
+          cvs_[i].wait(lk, [&] { return stop_pool_ || !queues_[i].empty(); });
+          if (stop_pool_ && queues_[i].empty())
+            return;

           task = std::move(queues_[i].front());
           queues_[i].pop();
         }
         task(); //
       }
     });
   }
-
 }

 TPTransferThreadGroup::~TPTransferThreadGroup() {
   stop_pool_ = true;
-  for (auto& cv : cvs_) cv.notify_all();
-  for (auto& t : threads_) if (t.joinable()) t.join();
+  for (auto &cv : cvs_)
+    cv.notify_all();
+  for (auto &t : threads_)
+    if (t.joinable())
+      t.join();

-  cudaFreeHost(gpu_blocks_);
+  cudaFreeHost(k_gpu_blocks_);
+  cudaFreeHost(v_gpu_blocks_);

-  delete[] gpu_kv_strides_in_bytes_;
   delete[] gpu_block_strides_in_bytes_;
   delete[] gpu_chunk_sizes_in_bytes_;
 }

-std::future<void> TPTransferThreadGroup::enqueue_for_gpu(int gpu_idx, Task task) {
+std::future<void> TPTransferThreadGroup::enqueue_for_gpu(int gpu_idx,
+                                                         Task task) {
   auto pkg = std::make_shared<std::packaged_task<void()>>(std::move(task));
   auto fut = pkg->get_future();
   {
     std::lock_guard<std::mutex> lk(mtxs_[gpu_idx]);
-    queues_[gpu_idx].emplace([pkg]{ (*pkg)(); });
+    queues_[gpu_idx].emplace([pkg] { (*pkg)(); });
   }
   cvs_[gpu_idx].notify_one();
   return fut;
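
The enqueue_for_gpu path above is a textbook packaged_task bridge: the caller receives a std::future that completes when the pinned worker thread executes the task, and wrapping the move-only packaged_task in a shared_ptr makes the queued closure copyable, which std::function requires. A self-contained miniature of the same pattern, reduced to one queue and one worker:

#include <condition_variable>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>

int main() {
  std::queue<std::function<void()>> q;
  std::mutex m;
  std::condition_variable cv;
  bool stop = false;

  std::thread worker([&] {
    while (true) {
      std::function<void()> job;
      {
        std::unique_lock<std::mutex> lk(m);
        cv.wait(lk, [&] { return stop || !q.empty(); });
        if (stop && q.empty())
          return;
        job = std::move(q.front());
        q.pop();
      }
      job();
    }
  });

  // packaged_task is move-only; hold it via shared_ptr so the queued
  // std::function (which must be copyable) can capture it.
  auto pkg = std::make_shared<std::packaged_task<void()>>(
      [] { std::cout << "transfer done\n"; });
  std::future<void> fut = pkg->get_future();
  {
    std::lock_guard<std::mutex> lk(m);
    q.emplace([pkg] { (*pkg)(); });
  }
  cv.notify_one();

  fut.get(); // blocks until the worker has run the task
  {
    std::lock_guard<std::mutex> lk(m);
    stop = true;
  }
  cv.notify_all();
  worker.join();
  return 0;
}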
@@ -130,7 +142,7 @@ void TPTransferThreadGroup::tp_group_transfer(
   std::vector<std::future<void>> futures;
   futures.reserve(num_gpus_);

-  for (int i=0; i<num_gpus_; ++i){
+  for (int i = 0; i < num_gpus_; ++i) {
     futures.emplace_back(enqueue_for_gpu(i, [&, i]() {
       try {
         int num_blocks = gpu_block_id_tensor.numel();
@@ -140,31 +152,36 @@ void TPTransferThreadGroup::tp_group_transfer(
             static_cast<int64_t *>(gpu_block_id_tensor.data_ptr());
         int64_t *cpu_block_ids =
             static_cast<int64_t *>(cpu_block_id_tensor.data_ptr());
-        void **gpu_layer_ptrs =
-            static_cast<void **>(gpu_blocks_ + i * num_layers + layer_id);
+        void **k_gpu_layer_ptrs =
+            static_cast<void **>(k_gpu_blocks_ + i * num_layers + layer_id);
+        void **v_gpu_layer_ptrs =
+            static_cast<void **>(v_gpu_blocks_ + i * num_layers + layer_id);
         void *cpu_ptr = cpu_blocks_;
         int64_t cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i];
         if (is_mla && !is_host_to_device) {
-          cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_;
+          cpu_startoff_inside_chunks =
+              i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_;
         } else if (is_mla && is_host_to_device) {
           cpu_startoff_inside_chunks = 0;
         }
-        int64_t gpu_startoff_inside_chunks =
-            is_mla && !is_host_to_device ? i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_ : 0;
+        int64_t gpu_startoff_inside_chunks =
+            is_mla && !is_host_to_device
+                ? i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_
+                : 0;
         // we assume that the chunk size is the same for all gpus,
         // even if they have different number of gpu_blocks
-        int64_t chunk_size = is_mla && !is_host_to_device ?
-            gpu_chunk_sizes_in_bytes_[i] / num_gpus_ : gpu_chunk_sizes_in_bytes_[i];
-
+        int64_t chunk_size = is_mla && !is_host_to_device
+                                 ? gpu_chunk_sizes_in_bytes_[i] / num_gpus_
+                                 : gpu_chunk_sizes_in_bytes_[i];
+
         flexkv::transfer_kv_blocks(
-            num_blocks, layer_id, layer_granularity, gpu_block_ids,
-            gpu_layer_ptrs, gpu_kv_strides_in_bytes_[i], gpu_block_strides_in_bytes_[i],
-            gpu_startoff_inside_chunks,
-            cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes,
-            cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes,
-            cpu_startoff_inside_chunks, chunk_size, streams_[i],
-            transfer_sms, is_host_to_device, use_ce_transfer, is_mla
-        );
+            num_blocks, layer_id, layer_granularity, gpu_block_ids,
+            k_gpu_layer_ptrs, v_gpu_layer_ptrs, gpu_block_strides_in_bytes_[i],
+            gpu_startoff_inside_chunks, cpu_block_ids, cpu_ptr,
+            cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes,
+            cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size,
+            streams_[i], transfer_sms, is_host_to_device, use_ce_transfer,
+            is_mla);

         cudaError_t err = cudaGetLastError();
         if (err != cudaSuccess) {
@@ -175,11 +192,10 @@ void TPTransferThreadGroup::tp_group_transfer(
           failed = true;
           error_msg = e.what();
         }
-
      }));
   }

-  for (auto &f : futures){
+  for (auto &f : futures) {
     f.get();
   }
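
The offset arithmetic above is the subtle part of this hunk. In the default (non-MLA) layout, TP rank i owns a whole chunk-sized slot at CPU offset i * chunk. For MLA device-to-host, each rank instead flushes a disjoint 1/num_gpus_ slice, so both start offsets become i * chunk / num_gpus_ and the per-rank size shrinks by the same factor; for MLA host-to-device, every rank reads the full chunk from offset 0, presumably because the MLA latent KV is replicated across the TP group. A small standalone check of that arithmetic with toy numbers (it assumes equal chunk sizes across GPUs, as the in-code comment requires):

#include <cassert>
#include <cstdint>

int main() {
  const int num_gpus = 4;
  const int64_t chunk = 4096; // per-layer chunk size in bytes, same on all GPUs
  for (int64_t i = 0; i < num_gpus; ++i) {
    // Non-MLA: rank i owns a full chunk-sized slot at offset i * chunk.
    int64_t cpu_off = i * chunk;
    assert(cpu_off + chunk <= num_gpus * chunk);

    // MLA, device-to-host: disjoint 1/num_gpus slices of one shared chunk.
    int64_t d2h_off = i * chunk / num_gpus;
    int64_t d2h_size = chunk / num_gpus;
    assert(d2h_off + d2h_size <= chunk); // the slices tile the chunk exactly

    // MLA, host-to-device: every rank reads the whole chunk from offset 0.
    int64_t h2d_off = 0;
    int64_t h2d_size = chunk;
    assert(h2d_off + h2d_size == chunk);
  }
  return 0;
}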
