Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 70 additions & 16 deletions csrc/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <fcntl.h>
#include <nvToolsExt.h>
#include <nvtx3/nvToolsExt.h>
#include <pybind11/pybind11.h>
#include <sys/mman.h>
#include <sys/stat.h>
Expand All @@ -26,31 +26,72 @@
namespace py = pybind11;

void transfer_kv_blocks_binding(
torch::Tensor &gpu_block_id_tensor, torch::Tensor &gpu_layer_ptrs_tensor,
int64_t gpu_kv_stride_in_bytes, int64_t gpu_block_stride_in_bytes,
torch::Tensor &gpu_block_id_tensor, torch::Tensor &gpu_tensor_ptrs_tensor,
int64_t gpu_kv_stride_in_bytes, int64_t gpu_block_stride_in_bytes, int64_t gpu_layer_stride_in_bytes,
torch::Tensor &cpu_block_id_tensor, torch::Tensor &cpu_tensor,
int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes,
int64_t cpu_block_stride_in_bytes, int64_t chunk_size_in_bytes,
int start_layer_id, int transfer_sms = -1, bool is_host_to_device = true,
bool use_ce_transfer = false, bool is_mla = false) {
int start_layer_id, int num_layers, int transfer_sms = -1, bool is_host_to_device = true,
bool use_ce_transfer = false, bool is_mla = false, int gpu_block_type = 0) {
int num_blocks = gpu_block_id_tensor.numel();
int num_layers = gpu_layer_ptrs_tensor.numel();

int64_t *gpu_block_ids =
static_cast<int64_t *>(gpu_block_id_tensor.data_ptr());
void **gpu_layer_ptrs = static_cast<void **>(
gpu_layer_ptrs_tensor.data_ptr()); // must be contiguous
void **gpu_tensor_ptrs = static_cast<void **>(
gpu_tensor_ptrs_tensor.data_ptr()); // must be contiguous
int64_t *cpu_block_ids =
static_cast<int64_t *>(cpu_block_id_tensor.data_ptr());
void *cpu_ptr = static_cast<void *>(cpu_tensor.data_ptr());

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
flexkv::transfer_kv_blocks(
num_blocks, start_layer_id, num_layers, gpu_block_ids, gpu_layer_ptrs,
gpu_kv_stride_in_bytes, gpu_block_stride_in_bytes, 0, cpu_block_ids, cpu_ptr,
cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes,
cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms,
is_host_to_device, use_ce_transfer, is_mla);

// Determine backend type from gpu_block_type parameter
flexkv::BackendType backend_type;
if (gpu_block_type == 0) {
backend_type = flexkv::BackendType::VLLM;
} else if (gpu_block_type == 1) {
backend_type = flexkv::BackendType::TRTLLM;
} else if (gpu_block_type == 2) {
backend_type = flexkv::BackendType::SGLANG;
} else {
throw std::runtime_error("Unsupported gpu_block_type: " + std::to_string(gpu_block_type));
}

// Create GTensorHandler
flexkv::GTensorHandler handler(
backend_type,
reinterpret_cast<int64_t**>(gpu_tensor_ptrs),
num_layers,
gpu_kv_stride_in_bytes,
gpu_block_stride_in_bytes,
gpu_layer_stride_in_bytes
);

// Dispatch to appropriate template instantiation
switch (backend_type) {
case flexkv::BackendType::VLLM:
flexkv::transfer_kv_blocks<flexkv::BackendType::VLLM>(
num_blocks, start_layer_id, num_layers, gpu_block_ids, handler, 0,
cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes,
cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms,
is_host_to_device, use_ce_transfer, is_mla);
break;
case flexkv::BackendType::TRTLLM:
flexkv::transfer_kv_blocks<flexkv::BackendType::TRTLLM>(
num_blocks, start_layer_id, num_layers, gpu_block_ids, handler, 0,
cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes,
cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms,
is_host_to_device, use_ce_transfer, is_mla);
break;
case flexkv::BackendType::SGLANG:
flexkv::transfer_kv_blocks<flexkv::BackendType::SGLANG>(
num_blocks, start_layer_id, num_layers, gpu_block_ids, handler, 0,
cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes,
cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms,
is_host_to_device, use_ce_transfer, is_mla);
break;
}

cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
throw std::runtime_error(cudaGetErrorString(err));
Expand Down Expand Up @@ -258,7 +299,16 @@ bool create_gds_file_binding(GDSManager& manager,

PYBIND11_MODULE(c_ext, m) {
m.def("transfer_kv_blocks", &transfer_kv_blocks_binding,
"Transfer multi-layer KV-cache between CPU and GPU");
"Transfer multi-layer KV-cache between CPU and GPU",
py::arg("gpu_block_id_tensor"), py::arg("gpu_tensor_ptrs_tensor"),
py::arg("gpu_kv_stride_in_bytes"), py::arg("gpu_block_stride_in_bytes"),
py::arg("gpu_layer_stride_in_bytes"), py::arg("cpu_block_id_tensor"),
py::arg("cpu_tensor"), py::arg("cpu_kv_stride_in_bytes"),
py::arg("cpu_layer_stride_in_bytes"), py::arg("cpu_block_stride_in_bytes"),
py::arg("chunk_size_in_bytes"), py::arg("start_layer_id"),
py::arg("num_layers"), py::arg("transfer_sms") = -1,
py::arg("is_host_to_device") = true, py::arg("use_ce_transfer") = false,
py::arg("is_mla") = false, py::arg("gpu_block_type") = 0);
m.def("transfer_kv_blocks_ssd", &transfer_kv_blocks_ssd_binding,
"Transfer KV blocks between SSD and CPU memory",
py::arg("ioctx"), py::arg("cpu_layer_id_list"),
Expand Down Expand Up @@ -303,7 +353,11 @@ PYBIND11_MODULE(c_ext, m) {

py::class_<flexkv::TPTransferThreadGroup>(m, "TPTransferThreadGroup")
.def(py::init<int, const std::vector<std::vector<torch::Tensor>> &,
torch::Tensor &, int, torch::Tensor &, torch::Tensor &, torch::Tensor &>())
torch::Tensor &, int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &, torch::Tensor &>(),
py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"),
py::arg("dp_group_id"), py::arg("num_layers"),
py::arg("gpu_kv_strides_tensor"), py::arg("gpu_block_strides_tensor"),
py::arg("gpu_layer_strides_tensor"), py::arg("gpu_chunk_sizes_tensor"))
.def("tp_group_transfer",
&flexkv::TPTransferThreadGroup::tp_group_transfer,
py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"),
Expand Down
83 changes: 83 additions & 0 deletions csrc/gtensor_handler.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#pragma once

#include <cassert>
#include <cuda_runtime.h>

namespace flexkv {

// Backend type enumeration
// Identifies the GPU KV-cache tensor layout used by the serving framework;
// it selects which ptr_at<> specialization maps (layer, kv, block) to an address.
enum class BackendType {
  VLLM,    // one tensor per layer: base pointer indexed by layer_idx
  TRTLLM,  // one pooled tensor: single base pointer, layer/kv/block via strides
  SGLANG   // one tensor per (kv, layer): base pointer indexed by kv_idx * num_layers + layer_idx
};

// Simplified GTensorHandler - no inheritance, just a data structure.
//
// Bundles the per-GPU tensor pointer table and layout strides needed to map a
// (layer, kv, block) coordinate onto a device address (see ptr_at<> below).
// Strides are supplied in bytes but stored in int64_t *elements*, because all
// address arithmetic in ptr_at<> is performed on int64_t* pointers.
struct GTensorHandler {
  BackendType type;          // layout flavor; selects the ptr_at<> specialization
  int64_t **gpu_tensor_ptrs; // tensor base-pointer table; indexing scheme depends on `type`
  int64_t num_layers;        // layer count; used by the SGLANG pointer-table indexing
  int64_t gpu_kv_stride;     // K->V stride, in int64_t elements
  int64_t gpu_block_stride;  // block stride, in int64_t elements
  int64_t gpu_layer_stride;  // layer stride, in int64_t elements (pooled TRTLLM layout)

  // Default-construct an empty handler (null table, zero strides). Not usable
  // for transfers; exists so the struct can live in containers / be assigned.
  __host__ __device__
  GTensorHandler()
      : type(BackendType::VLLM),
        gpu_tensor_ptrs(nullptr),
        num_layers(0),
        gpu_kv_stride(0),
        gpu_block_stride(0),
        gpu_layer_stride(0) {}

  // Construct from byte strides.
  //
  // Precondition: every byte stride is a multiple of sizeof(int64_t). The
  // integer divisions below would otherwise silently truncate and corrupt all
  // subsequent address arithmetic, so the precondition is asserted in debug
  // builds (asserts compile out under NDEBUG, so release behavior is unchanged).
  __host__ __device__
  GTensorHandler(BackendType type,
                 int64_t **gpu_tensor_ptrs,
                 int64_t num_layers,
                 int64_t gpu_kv_stride_in_bytes,
                 int64_t gpu_block_stride_in_bytes,
                 int64_t gpu_layer_stride_in_bytes)
      : type(type),
        gpu_tensor_ptrs(gpu_tensor_ptrs),
        num_layers(num_layers),
        gpu_kv_stride(gpu_kv_stride_in_bytes / (int64_t)sizeof(int64_t)),
        gpu_block_stride(gpu_block_stride_in_bytes / (int64_t)sizeof(int64_t)),
        gpu_layer_stride(gpu_layer_stride_in_bytes / (int64_t)sizeof(int64_t)) {
    assert(gpu_kv_stride_in_bytes % (int64_t)sizeof(int64_t) == 0);
    assert(gpu_block_stride_in_bytes % (int64_t)sizeof(int64_t) == 0);
    assert(gpu_layer_stride_in_bytes % (int64_t)sizeof(int64_t) == 0);
  }
};

// Template specialization for different backends
// Forward declaration
//
// ptr_at<Type> maps a (layer, kv, block) coordinate onto the address of the
// corresponding GPU KV-cache data, using the pointer table and element strides
// held by `handler`. Only the specializations below are defined; instantiating
// an unspecialized Type fails at link time.
template<BackendType Type>
__device__ __host__ inline
int64_t* ptr_at(const GTensorHandler& handler,
                int64_t layer_idx, int64_t kv_idx, int64_t block_idx);

// vLLM specialization: one tensor per layer. The layer index selects the base
// pointer from the table; the K/V and block offsets are applied within that
// layer's tensor. All strides are in int64_t elements.
template<>
__device__ __host__ inline
int64_t* ptr_at<BackendType::VLLM>(const GTensorHandler& handler,
                                   int64_t layer_idx, int64_t kv_idx, int64_t block_idx) {
  int64_t *layer_base = handler.gpu_tensor_ptrs[layer_idx];
  const int64_t offset =
      block_idx * handler.gpu_block_stride + kv_idx * handler.gpu_kv_stride;
  return layer_base + offset;
}

// TRT-LLM specialization: one pooled tensor shared by all layers. All three
// coordinates become element-stride offsets from the single base pointer
// stored at slot 0 of the table.
template<>
__device__ __host__ inline
int64_t* ptr_at<BackendType::TRTLLM>(const GTensorHandler& handler,
                                     int64_t layer_idx, int64_t kv_idx, int64_t block_idx) {
  int64_t *pool_base = handler.gpu_tensor_ptrs[0];
  int64_t offset = kv_idx * handler.gpu_kv_stride;
  offset += layer_idx * handler.gpu_layer_stride;
  offset += block_idx * handler.gpu_block_stride;
  return pool_base + offset;
}

// SGLang specialization: one tensor per (kv, layer) pair. The pointer table is
// laid out kv-major (kv_idx * num_layers + layer_idx); only the block offset
// is applied inside the selected tensor.
template<>
__device__ __host__ inline
int64_t* ptr_at<BackendType::SGLANG>(const GTensorHandler& handler,
                                     int64_t layer_idx, int64_t kv_idx, int64_t block_idx) {
  const int64_t table_idx = kv_idx * handler.num_layers + layer_idx;
  int64_t *tensor_base = handler.gpu_tensor_ptrs[table_idx];
  return tensor_base + block_idx * handler.gpu_block_stride;
}

} // namespace flexkv
82 changes: 67 additions & 15 deletions csrc/tp_transfer_thread_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,67 @@ namespace flexkv {
TPTransferThreadGroup::TPTransferThreadGroup(
int num_gpus, const std::vector<std::vector<torch::Tensor>> &gpu_blocks,
torch::Tensor &cpu_blocks, int dp_group_id,
int num_layers,
torch::Tensor &gpu_kv_strides_tensor,
torch::Tensor &gpu_block_strides_tensor,
torch::Tensor &gpu_layer_strides_tensor,
torch::Tensor &gpu_chunk_sizes_tensor) {

num_gpus_ = num_gpus;

gpu_kv_strides_in_bytes_ = new int64_t[num_gpus];
gpu_block_strides_in_bytes_ = new int64_t[num_gpus];
gpu_layer_strides_in_bytes_ = new int64_t[num_gpus];
gpu_chunk_sizes_in_bytes_ = new int64_t[num_gpus];

int64_t* kv_strides_ptr = gpu_kv_strides_tensor.data_ptr<int64_t>();
int64_t* block_strides_ptr = gpu_block_strides_tensor.data_ptr<int64_t>();
int64_t* layer_strides_ptr = gpu_layer_strides_tensor.data_ptr<int64_t>();
int64_t* chunk_sizes_ptr = gpu_chunk_sizes_tensor.data_ptr<int64_t>();

for (int i = 0; i < num_gpus; i++) {
gpu_kv_strides_in_bytes_[i] = kv_strides_ptr[i];
gpu_block_strides_in_bytes_[i] = block_strides_ptr[i];
gpu_chunk_sizes_in_bytes_[i] = chunk_sizes_ptr[i];
gpu_layer_strides_in_bytes_[i] = layer_strides_ptr[i];
}

queues_.resize(num_gpus_);
mtxs_ = std::vector<std::mutex>(num_gpus_);
cvs_ = std::vector<std::condition_variable>(num_gpus_);

int num_layers = gpu_blocks[0].size();
num_tensors_per_gpu_ = gpu_blocks[0].size();
cudaMallocHost((void **)&gpu_blocks_,
num_gpus_ * num_layers * sizeof(void *));
num_gpus_ * num_tensors_per_gpu_ * sizeof(void *));
for (int i = 0; i < num_gpus_; ++i) {
for (int j = 0; j < num_layers; ++j) {
gpu_blocks_[i * num_layers + j] = gpu_blocks[i][j].data_ptr();
for (int j = 0; j < num_tensors_per_gpu_; ++j) {
gpu_blocks_[i * num_tensors_per_gpu_ + j] = gpu_blocks[i][j].data_ptr();
}
}

if (num_tensors_per_gpu_ == 1) {
backend_type_ = BackendType::TRTLLM;
} else if (num_tensors_per_gpu_ == num_layers) {
backend_type_ = BackendType::VLLM;
} else if (num_tensors_per_gpu_ == num_layers * 2) {
backend_type_ = BackendType::SGLANG;
} else {
throw std::runtime_error("Unsupported GPU block type: " + std::to_string(num_tensors_per_gpu_));
}

gpu_tensor_handlers_.reserve(num_gpus_);
for (int i = 0; i < num_gpus_; i++) {
int64_t **gpu_blocks_ptr = reinterpret_cast<int64_t**>(gpu_blocks_ + i * num_tensors_per_gpu_);
gpu_tensor_handlers_.emplace_back(
backend_type_,
gpu_blocks_ptr,
num_layers,
gpu_kv_strides_in_bytes_[i],
gpu_block_strides_in_bytes_[i],
gpu_layer_strides_in_bytes_[i]
);
}

cpu_blocks_ = cpu_blocks.data_ptr();

dp_group_id_ = dp_group_id;
Expand Down Expand Up @@ -95,8 +123,10 @@ TPTransferThreadGroup::~TPTransferThreadGroup() {

cudaFreeHost(gpu_blocks_);

gpu_tensor_handlers_.clear();
delete[] gpu_kv_strides_in_bytes_;
delete[] gpu_block_strides_in_bytes_;
delete[] gpu_layer_strides_in_bytes_;
delete[] gpu_chunk_sizes_in_bytes_;
}

Expand Down Expand Up @@ -140,8 +170,6 @@ void TPTransferThreadGroup::tp_group_transfer(
static_cast<int64_t *>(gpu_block_id_tensor.data_ptr());
int64_t *cpu_block_ids =
static_cast<int64_t *>(cpu_block_id_tensor.data_ptr());
void **gpu_layer_ptrs =
static_cast<void **>(gpu_blocks_ + i * num_layers + layer_id);
void *cpu_ptr = cpu_blocks_;
int64_t cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i];
if (is_mla && !is_host_to_device) {
Expand All @@ -156,15 +184,39 @@ void TPTransferThreadGroup::tp_group_transfer(
int64_t chunk_size = is_mla && !is_host_to_device ?
gpu_chunk_sizes_in_bytes_[i] / num_gpus_ : gpu_chunk_sizes_in_bytes_[i];

flexkv::transfer_kv_blocks(
num_blocks, layer_id, layer_granularity, gpu_block_ids,
gpu_layer_ptrs, gpu_kv_strides_in_bytes_[i], gpu_block_strides_in_bytes_[i],
gpu_startoff_inside_chunks,
cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes,
cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes,
cpu_startoff_inside_chunks, chunk_size, streams_[i],
transfer_sms, is_host_to_device, use_ce_transfer, is_mla
);
// Dispatch to the appropriate template based on backend type
switch (backend_type_) {
case BackendType::VLLM:
flexkv::transfer_kv_blocks<BackendType::VLLM>(
num_blocks, layer_id, layer_granularity, gpu_block_ids,
gpu_tensor_handlers_[i], gpu_startoff_inside_chunks,
cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes,
cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes,
cpu_startoff_inside_chunks, chunk_size, streams_[i],
transfer_sms, is_host_to_device, use_ce_transfer, is_mla
);
break;
case BackendType::TRTLLM:
flexkv::transfer_kv_blocks<BackendType::TRTLLM>(
num_blocks, layer_id, layer_granularity, gpu_block_ids,
gpu_tensor_handlers_[i], gpu_startoff_inside_chunks,
cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes,
cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes,
cpu_startoff_inside_chunks, chunk_size, streams_[i],
transfer_sms, is_host_to_device, use_ce_transfer, is_mla
);
break;
case BackendType::SGLANG:
flexkv::transfer_kv_blocks<BackendType::SGLANG>(
num_blocks, layer_id, layer_granularity, gpu_block_ids,
gpu_tensor_handlers_[i], gpu_startoff_inside_chunks,
cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes,
cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes,
cpu_startoff_inside_chunks, chunk_size, streams_[i],
transfer_sms, is_host_to_device, use_ce_transfer, is_mla
);
break;
}

cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
Expand Down
14 changes: 13 additions & 1 deletion csrc/tp_transfer_thread_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,21 @@
#include <queue>
#include <functional>
#include <future>
#include <string>
#include "transfer.cuh"
#include "gtensor_handler.cuh"

namespace flexkv {

class TPTransferThreadGroup {
public:
TPTransferThreadGroup(
int num_gpus, const std::vector<std::vector<torch::Tensor>> &gpu_blocks,
torch::Tensor &cpu_blocks, int dp_group_id,
int num_layers,
torch::Tensor &gpu_kv_strides_tensor,
torch::Tensor &gpu_block_strides_tensor,
torch::Tensor &gpu_layer_strides_tensor,
torch::Tensor &gpu_chunk_sizes_tensor);
~TPTransferThreadGroup();

Expand All @@ -57,11 +64,16 @@ class TPTransferThreadGroup {
int dp_group_id_;
void **gpu_blocks_;
void *cpu_blocks_;

int num_tensors_per_gpu_;
int64_t *gpu_kv_strides_in_bytes_;
int64_t *gpu_block_strides_in_bytes_;
int64_t *gpu_layer_strides_in_bytes_;
int64_t *gpu_chunk_sizes_in_bytes_;

// Simplified: just one vector of handlers, runtime backend type selection
BackendType backend_type_;
std::vector<GTensorHandler> gpu_tensor_handlers_;

std::vector<std::thread> threads_;
std::vector<cudaStream_t> streams_;

Expand Down
Loading