diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 216ea365ec..03d468a1ac 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include #include #include @@ -26,31 +26,72 @@ namespace py = pybind11; void transfer_kv_blocks_binding( - torch::Tensor &gpu_block_id_tensor, torch::Tensor &gpu_layer_ptrs_tensor, - int64_t gpu_kv_stride_in_bytes, int64_t gpu_block_stride_in_bytes, + torch::Tensor &gpu_block_id_tensor, torch::Tensor &gpu_tensor_ptrs_tensor, + int64_t gpu_kv_stride_in_bytes, int64_t gpu_block_stride_in_bytes, int64_t gpu_layer_stride_in_bytes, torch::Tensor &cpu_block_id_tensor, torch::Tensor &cpu_tensor, int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t chunk_size_in_bytes, - int start_layer_id, int transfer_sms = -1, bool is_host_to_device = true, - bool use_ce_transfer = false, bool is_mla = false) { + int start_layer_id, int num_layers, int transfer_sms = -1, bool is_host_to_device = true, + bool use_ce_transfer = false, bool is_mla = false, int gpu_block_type = 0) { int num_blocks = gpu_block_id_tensor.numel(); - int num_layers = gpu_layer_ptrs_tensor.numel(); int64_t *gpu_block_ids = static_cast(gpu_block_id_tensor.data_ptr()); - void **gpu_layer_ptrs = static_cast( - gpu_layer_ptrs_tensor.data_ptr()); // must be contiguous + void **gpu_tensor_ptrs = static_cast( + gpu_tensor_ptrs_tensor.data_ptr()); // must be contiguous int64_t *cpu_block_ids = static_cast(cpu_block_id_tensor.data_ptr()); void *cpu_ptr = static_cast(cpu_tensor.data_ptr()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - flexkv::transfer_kv_blocks( - num_blocks, start_layer_id, num_layers, gpu_block_ids, gpu_layer_ptrs, - gpu_kv_stride_in_bytes, gpu_block_stride_in_bytes, 0, cpu_block_ids, cpu_ptr, - cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, - cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms, - is_host_to_device, use_ce_transfer, is_mla); + + // Determine backend type from gpu_block_type parameter + flexkv::BackendType backend_type; + if (gpu_block_type == 0) { + backend_type = flexkv::BackendType::VLLM; + } else if (gpu_block_type == 1) { + backend_type = flexkv::BackendType::TRTLLM; + } else if (gpu_block_type == 2) { + backend_type = flexkv::BackendType::SGLANG; + } else { + throw std::runtime_error("Unsupported gpu_block_type: " + std::to_string(gpu_block_type)); + } + + // Create GTensorHandler + flexkv::GTensorHandler handler( + backend_type, + reinterpret_cast(gpu_tensor_ptrs), + num_layers, + gpu_kv_stride_in_bytes, + gpu_block_stride_in_bytes, + gpu_layer_stride_in_bytes + ); + + // Dispatch to appropriate template instantiation + switch (backend_type) { + case flexkv::BackendType::VLLM: + flexkv::transfer_kv_blocks( + num_blocks, start_layer_id, num_layers, gpu_block_ids, handler, 0, + cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms, + is_host_to_device, use_ce_transfer, is_mla); + break; + case flexkv::BackendType::TRTLLM: + flexkv::transfer_kv_blocks( + num_blocks, start_layer_id, num_layers, gpu_block_ids, handler, 0, + cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms, + is_host_to_device, use_ce_transfer, is_mla); + break; + case flexkv::BackendType::SGLANG: + flexkv::transfer_kv_blocks( + num_blocks, start_layer_id, num_layers, gpu_block_ids, handler, 0, + cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms, + is_host_to_device, use_ce_transfer, is_mla); + break; + } + cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { throw std::runtime_error(cudaGetErrorString(err)); @@ -258,7 +299,16 @@ bool create_gds_file_binding(GDSManager& manager, PYBIND11_MODULE(c_ext, m) { m.def("transfer_kv_blocks", &transfer_kv_blocks_binding, - "Transfer multi-layer KV-cache between CPU and GPU"); + "Transfer multi-layer KV-cache between CPU and GPU", + py::arg("gpu_block_id_tensor"), py::arg("gpu_tensor_ptrs_tensor"), + py::arg("gpu_kv_stride_in_bytes"), py::arg("gpu_block_stride_in_bytes"), + py::arg("gpu_layer_stride_in_bytes"), py::arg("cpu_block_id_tensor"), + py::arg("cpu_tensor"), py::arg("cpu_kv_stride_in_bytes"), + py::arg("cpu_layer_stride_in_bytes"), py::arg("cpu_block_stride_in_bytes"), + py::arg("chunk_size_in_bytes"), py::arg("start_layer_id"), + py::arg("num_layers"), py::arg("transfer_sms") = -1, + py::arg("is_host_to_device") = true, py::arg("use_ce_transfer") = false, + py::arg("is_mla") = false, py::arg("gpu_block_type") = 0); m.def("transfer_kv_blocks_ssd", &transfer_kv_blocks_ssd_binding, "Transfer KV blocks between SSD and CPU memory", py::arg("ioctx"), py::arg("cpu_layer_id_list"), @@ -303,7 +353,11 @@ PYBIND11_MODULE(c_ext, m) { py::class_(m, "TPTransferThreadGroup") .def(py::init> &, - torch::Tensor &, int, torch::Tensor &, torch::Tensor &, torch::Tensor &>()) + torch::Tensor &, int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &, torch::Tensor &>(), + py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"), + py::arg("dp_group_id"), py::arg("num_layers"), + py::arg("gpu_kv_strides_tensor"), py::arg("gpu_block_strides_tensor"), + py::arg("gpu_layer_strides_tensor"), py::arg("gpu_chunk_sizes_tensor")) .def("tp_group_transfer", &flexkv::TPTransferThreadGroup::tp_group_transfer, py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"), diff --git a/csrc/gtensor_handler.cuh b/csrc/gtensor_handler.cuh new file mode 100644 index 0000000000..a2f32297ec --- /dev/null +++ b/csrc/gtensor_handler.cuh @@ -0,0 +1,83 @@ +#pragma once +#include + +namespace flexkv { + +// Backend type enumeration +enum class BackendType { + VLLM, + TRTLLM, + SGLANG +}; + +// Simplified GTensorHandler - no inheritance, just a data structure +struct GTensorHandler { + BackendType type; + int64_t **gpu_tensor_ptrs; + int64_t num_layers; + int64_t gpu_kv_stride; + int64_t gpu_block_stride; + int64_t gpu_layer_stride; + + __host__ __device__ + GTensorHandler() + : type(BackendType::VLLM), + gpu_tensor_ptrs(nullptr), + num_layers(0), + gpu_kv_stride(0), + gpu_block_stride(0), + gpu_layer_stride(0) {} + + __host__ __device__ + GTensorHandler(BackendType type, + int64_t **gpu_tensor_ptrs, + int64_t num_layers, + int64_t gpu_kv_stride_in_bytes, + int64_t gpu_block_stride_in_bytes, + int64_t gpu_layer_stride_in_bytes) + : type(type), + gpu_tensor_ptrs(gpu_tensor_ptrs), + num_layers(num_layers), + gpu_kv_stride(gpu_kv_stride_in_bytes / sizeof(int64_t)), + gpu_block_stride(gpu_block_stride_in_bytes / sizeof(int64_t)), + gpu_layer_stride(gpu_layer_stride_in_bytes / sizeof(int64_t)) {} +}; + +// Template specialization for different backends +// Forward declaration +template +__device__ __host__ inline +int64_t* ptr_at(const GTensorHandler& handler, + int64_t layer_idx, int64_t kv_idx, int64_t block_idx); + +// vLLM specialization +template<> +__device__ __host__ inline +int64_t* ptr_at(const GTensorHandler& handler, + int64_t layer_idx, int64_t kv_idx, int64_t block_idx) { + return handler.gpu_tensor_ptrs[layer_idx] + + kv_idx * handler.gpu_kv_stride + + block_idx * handler.gpu_block_stride; +} + +// TRT-LLM specialization +template<> +__device__ __host__ inline +int64_t* ptr_at(const GTensorHandler& handler, + int64_t layer_idx, int64_t kv_idx, int64_t block_idx) { + return handler.gpu_tensor_ptrs[0] + + block_idx * handler.gpu_block_stride + + layer_idx * handler.gpu_layer_stride + + kv_idx * handler.gpu_kv_stride; +} + +// SGLang specialization +template<> +__device__ __host__ inline +int64_t* ptr_at(const GTensorHandler& handler, + int64_t layer_idx, int64_t kv_idx, int64_t block_idx) { + return handler.gpu_tensor_ptrs[kv_idx * handler.num_layers + layer_idx] + + block_idx * handler.gpu_block_stride; +} + +} // namespace flexkv \ No newline at end of file diff --git a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index b108d4cd83..fdc4a9a762 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ b/csrc/tp_transfer_thread_group.cpp @@ -23,39 +23,67 @@ namespace flexkv { TPTransferThreadGroup::TPTransferThreadGroup( int num_gpus, const std::vector> &gpu_blocks, torch::Tensor &cpu_blocks, int dp_group_id, + int num_layers, torch::Tensor &gpu_kv_strides_tensor, torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_layer_strides_tensor, torch::Tensor &gpu_chunk_sizes_tensor) { num_gpus_ = num_gpus; gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; gpu_block_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_layer_strides_in_bytes_ = new int64_t[num_gpus]; gpu_chunk_sizes_in_bytes_ = new int64_t[num_gpus]; int64_t* kv_strides_ptr = gpu_kv_strides_tensor.data_ptr(); int64_t* block_strides_ptr = gpu_block_strides_tensor.data_ptr(); + int64_t* layer_strides_ptr = gpu_layer_strides_tensor.data_ptr(); int64_t* chunk_sizes_ptr = gpu_chunk_sizes_tensor.data_ptr(); for (int i = 0; i < num_gpus; i++) { gpu_kv_strides_in_bytes_[i] = kv_strides_ptr[i]; gpu_block_strides_in_bytes_[i] = block_strides_ptr[i]; gpu_chunk_sizes_in_bytes_[i] = chunk_sizes_ptr[i]; + gpu_layer_strides_in_bytes_[i] = layer_strides_ptr[i]; } queues_.resize(num_gpus_); mtxs_ = std::vector(num_gpus_); cvs_ = std::vector(num_gpus_); - int num_layers = gpu_blocks[0].size(); + num_tensors_per_gpu_ = gpu_blocks[0].size(); cudaMallocHost((void **)&gpu_blocks_, - num_gpus_ * num_layers * sizeof(void *)); + num_gpus_ * num_tensors_per_gpu_ * sizeof(void *)); for (int i = 0; i < num_gpus_; ++i) { - for (int j = 0; j < num_layers; ++j) { - gpu_blocks_[i * num_layers + j] = gpu_blocks[i][j].data_ptr(); + for (int j = 0; j < num_tensors_per_gpu_; ++j) { + gpu_blocks_[i * num_tensors_per_gpu_ + j] = gpu_blocks[i][j].data_ptr(); } } + if (num_tensors_per_gpu_ == 1) { + backend_type_ = BackendType::TRTLLM; + } else if (num_tensors_per_gpu_ == num_layers) { + backend_type_ = BackendType::VLLM; + } else if (num_tensors_per_gpu_ == num_layers * 2) { + backend_type_ = BackendType::SGLANG; + } else { + throw std::runtime_error("Unsupported GPU block type: " + std::to_string(num_tensors_per_gpu_)); + } + + gpu_tensor_handlers_.reserve(num_gpus_); + for (int i = 0; i < num_gpus_; i++) { + int64_t **gpu_blocks_ptr = reinterpret_cast(gpu_blocks_ + i * num_tensors_per_gpu_); + gpu_tensor_handlers_.emplace_back( + backend_type_, + gpu_blocks_ptr, + num_layers, + gpu_kv_strides_in_bytes_[i], + gpu_block_strides_in_bytes_[i], + gpu_layer_strides_in_bytes_[i] + ); + } + cpu_blocks_ = cpu_blocks.data_ptr(); dp_group_id_ = dp_group_id; @@ -95,8 +123,10 @@ TPTransferThreadGroup::~TPTransferThreadGroup() { cudaFreeHost(gpu_blocks_); + gpu_tensor_handlers_.clear(); delete[] gpu_kv_strides_in_bytes_; delete[] gpu_block_strides_in_bytes_; + delete[] gpu_layer_strides_in_bytes_; delete[] gpu_chunk_sizes_in_bytes_; } @@ -140,8 +170,6 @@ void TPTransferThreadGroup::tp_group_transfer( static_cast(gpu_block_id_tensor.data_ptr()); int64_t *cpu_block_ids = static_cast(cpu_block_id_tensor.data_ptr()); - void **gpu_layer_ptrs = - static_cast(gpu_blocks_ + i * num_layers + layer_id); void *cpu_ptr = cpu_blocks_; int64_t cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i]; if (is_mla && !is_host_to_device) { @@ -156,15 +184,39 @@ void TPTransferThreadGroup::tp_group_transfer( int64_t chunk_size = is_mla && !is_host_to_device ? gpu_chunk_sizes_in_bytes_[i] / num_gpus_ : gpu_chunk_sizes_in_bytes_[i]; - flexkv::transfer_kv_blocks( - num_blocks, layer_id, layer_granularity, gpu_block_ids, - gpu_layer_ptrs, gpu_kv_strides_in_bytes_[i], gpu_block_strides_in_bytes_[i], - gpu_startoff_inside_chunks, - cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, - cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, - cpu_startoff_inside_chunks, chunk_size, streams_[i], - transfer_sms, is_host_to_device, use_ce_transfer, is_mla - ); + // Dispatch to the appropriate template based on backend type + switch (backend_type_) { + case BackendType::VLLM: + flexkv::transfer_kv_blocks( + num_blocks, layer_id, layer_granularity, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, + cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, + cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, + cpu_startoff_inside_chunks, chunk_size, streams_[i], + transfer_sms, is_host_to_device, use_ce_transfer, is_mla + ); + break; + case BackendType::TRTLLM: + flexkv::transfer_kv_blocks( + num_blocks, layer_id, layer_granularity, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, + cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, + cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, + cpu_startoff_inside_chunks, chunk_size, streams_[i], + transfer_sms, is_host_to_device, use_ce_transfer, is_mla + ); + break; + case BackendType::SGLANG: + flexkv::transfer_kv_blocks( + num_blocks, layer_id, layer_granularity, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, + cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, + cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, + cpu_startoff_inside_chunks, chunk_size, streams_[i], + transfer_sms, is_host_to_device, use_ce_transfer, is_mla + ); + break; + } cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { diff --git a/csrc/tp_transfer_thread_group.h b/csrc/tp_transfer_thread_group.h index 3d57e569c5..6a53d79a1d 100644 --- a/csrc/tp_transfer_thread_group.h +++ b/csrc/tp_transfer_thread_group.h @@ -28,14 +28,21 @@ #include #include #include +#include +#include "transfer.cuh" +#include "gtensor_handler.cuh" + namespace flexkv { + class TPTransferThreadGroup { public: TPTransferThreadGroup( int num_gpus, const std::vector> &gpu_blocks, torch::Tensor &cpu_blocks, int dp_group_id, + int num_layers, torch::Tensor &gpu_kv_strides_tensor, torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_layer_strides_tensor, torch::Tensor &gpu_chunk_sizes_tensor); ~TPTransferThreadGroup(); @@ -57,11 +64,16 @@ class TPTransferThreadGroup { int dp_group_id_; void **gpu_blocks_; void *cpu_blocks_; - + int num_tensors_per_gpu_; int64_t *gpu_kv_strides_in_bytes_; int64_t *gpu_block_strides_in_bytes_; + int64_t *gpu_layer_strides_in_bytes_; int64_t *gpu_chunk_sizes_in_bytes_; + // Simplified: just one vector of handlers, runtime backend type selection + BackendType backend_type_; + std::vector gpu_tensor_handlers_; + std::vector threads_; std::vector streams_; diff --git a/csrc/transfer.cu b/csrc/transfer.cu index 938d227935..9dda5d4fe3 100644 --- a/csrc/transfer.cu +++ b/csrc/transfer.cu @@ -23,14 +23,18 @@ namespace flexkv { #define FLOAT4_PTR(ptr) reinterpret_cast(ptr) +// Templated CUDA kernel - backend type determined at compile time +template __global__ void transfer_kv_blocks_kernel( int num_blocks, int start_layer_id, int num_layers, int64_t *gpu_block_ids, - int64_t **gpu_layer_ptrs, int64_t gpu_kv_stride, int64_t gpu_block_stride, + GTensorHandler gpu_handler, int64_t gpu_startoff_inside_chunks, int64_t *cpu_block_ids, int64_t *cpu_ptr, int64_t cpu_kv_stride, int64_t cpu_layer_stride, int64_t cpu_block_stride, int64_t cpu_startoff_inside_chunks, int64_t copy_size, bool is_mla, - bool is_host_to_device) { + bool is_host_to_device) { + // start layer id should also be provided for gpu location calculation + // but for now, we only support full-layer transfer, so start_layer_id is always 0 int kv_dim = is_mla ? 1 : 2; int num_chunks = num_layers * kv_dim * num_blocks; int64_t copy_size_in_float4 = copy_size * sizeof(int64_t) / sizeof(float4); @@ -46,10 +50,10 @@ __global__ void transfer_kv_blocks_kernel( cpu_ptr + (layer_idx + start_layer_id) * cpu_layer_stride + kv_idx * cpu_kv_stride + cpu_block_idx * cpu_block_stride + cpu_startoff_inside_chunks; - int64_t *gpu_chunk_ptr = gpu_layer_ptrs[layer_idx] + - kv_idx * gpu_kv_stride + - gpu_block_idx * gpu_block_stride + - gpu_startoff_inside_chunks; + + // Use template specialization to compute gpu pointer + int64_t *gpu_ptr = ptr_at(gpu_handler, layer_idx, kv_idx, gpu_block_idx); + int64_t *gpu_chunk_ptr = reinterpret_cast(gpu_ptr) + gpu_startoff_inside_chunks; int64_t *src_chunk_ptr = is_host_to_device ? cpu_chunk_ptr : gpu_chunk_ptr; int64_t *dst_chunk_ptr = is_host_to_device ? gpu_chunk_ptr : cpu_chunk_ptr; @@ -62,20 +66,23 @@ __global__ void transfer_kv_blocks_kernel( } } +// Templated host function +template void transfer_kv_blocks( int num_blocks, int start_layer_id, int num_layers, int64_t *gpu_block_ids, - void **gpu_layer_ptrs, int64_t gpu_kv_stride_in_bytes, - int64_t gpu_block_stride_in_bytes, int64_t gpu_startoff_inside_chunks, + GTensorHandler gpu_tensor_handler, + int64_t gpu_startoff_inside_chunks, int64_t *cpu_block_ids, void *cpu_ptr, int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t cpu_startoff_inside_chunks, int64_t chunk_size_in_bytes, cudaStream_t stream, int transfer_sms, bool is_host_to_device, bool use_ce_transfer, bool is_mla) { + int block_size = 128; static int max_blocks_per_sm = -1; if (max_blocks_per_sm == -1) { cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_blocks_per_sm, transfer_kv_blocks_kernel, block_size, 0); + &max_blocks_per_sm, transfer_kv_blocks_kernel, block_size, 0); } if (transfer_sms == -1) { @@ -84,11 +91,8 @@ void transfer_kv_blocks( int block_count = transfer_sms * max_blocks_per_sm; - int64_t **gpu_layer_ptrs_int64 = reinterpret_cast(gpu_layer_ptrs); int64_t *cpu_ptr_int64 = reinterpret_cast(cpu_ptr); - int64_t gpu_kv_stride_int64 = gpu_kv_stride_in_bytes / sizeof(int64_t); int64_t cpu_kv_stride_int64 = cpu_kv_stride_in_bytes / sizeof(int64_t); - int64_t gpu_block_stride_int64 = gpu_block_stride_in_bytes / sizeof(int64_t); int64_t cpu_block_stride_int64 = cpu_block_stride_in_bytes / sizeof(int64_t); int64_t cpu_layer_stride_int64 = cpu_layer_stride_in_bytes / sizeof(int64_t); int64_t cpu_startoff_inside_chunks_int64 = @@ -99,20 +103,23 @@ void transfer_kv_blocks( dim3 blockDim(block_size); dim3 gridDim(block_count); + + // CE transfer mode (Copy Engine using cudaMemcpyAsync) if (use_ce_transfer) { + int kv_dim = is_mla ? 1 : 2; for (int i = 0; i < num_layers; i++) { - int kv_dim = is_mla ? 1 : 2; for (int j = 0; j < kv_dim; j++) { for (int k = 0; k < num_blocks; k++) { int64_t gpu_block_idx = gpu_block_ids[k]; int64_t cpu_block_idx = cpu_block_ids[k]; + int64_t *cpu_chunk_ptr = cpu_ptr_int64 + (i + start_layer_id) * cpu_layer_stride_int64 + j * cpu_kv_stride_int64 + cpu_block_idx * cpu_block_stride_int64 + cpu_startoff_inside_chunks_int64; - int64_t *gpu_chunk_ptr = gpu_layer_ptrs_int64[i] + - j * gpu_kv_stride_int64 + - gpu_block_idx * gpu_block_stride_int64 + + + int64_t *gpu_ptr = ptr_at(gpu_tensor_handler, i, j, gpu_block_idx); + int64_t *gpu_chunk_ptr = reinterpret_cast(gpu_ptr) + gpu_startoff_inside_chunks_int64; if (is_host_to_device) { @@ -126,10 +133,10 @@ void transfer_kv_blocks( } } } else { - transfer_kv_blocks_kernel<<>>( + // Custom kernel transfer + transfer_kv_blocks_kernel<<>>( num_blocks, start_layer_id, num_layers, gpu_block_ids, - gpu_layer_ptrs_int64, gpu_kv_stride_int64, gpu_block_stride_int64, - gpu_startoff_inside_chunks_int64, + gpu_tensor_handler, gpu_startoff_inside_chunks_int64, cpu_block_ids, cpu_ptr_int64, cpu_kv_stride_int64, cpu_layer_stride_int64, cpu_block_stride_int64, cpu_startoff_inside_chunks_int64, chunk_size_in_int64, is_mla, @@ -138,4 +145,17 @@ void transfer_kv_blocks( cudaStreamSynchronize(stream); } +// Explicit template instantiations +template void transfer_kv_blocks( + int, int, int, int64_t*, GTensorHandler, int64_t, int64_t*, void*, + int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, bool, bool, bool); + +template void transfer_kv_blocks( + int, int, int, int64_t*, GTensorHandler, int64_t, int64_t*, void*, + int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, bool, bool, bool); + +template void transfer_kv_blocks( + int, int, int, int64_t*, GTensorHandler, int64_t, int64_t*, void*, + int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, bool, bool, bool); + } // namespace flexkv diff --git a/csrc/transfer.cuh b/csrc/transfer.cuh index ad46fc5be1..5e834be660 100644 --- a/csrc/transfer.cuh +++ b/csrc/transfer.cuh @@ -17,13 +17,16 @@ #pragma once #include +#include "gtensor_handler.cuh" namespace flexkv { +// Template function for transfer, specialized for each backend type +template void transfer_kv_blocks( int num_blocks, int start_layer_id, int num_layers, int64_t *gpu_block_ids, - void **gpu_layer_ptrs, int64_t gpu_kv_stride_in_bytes, - int64_t gpu_block_stride_in_bytes, int64_t gpu_startoff_inside_chunks, + GTensorHandler gpu_tensor_handler, // Pass by value! + int64_t gpu_startoff_inside_chunks, int64_t *cpu_block_ids, void *cpu_ptr, int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t cpu_startoff_inside_chunks, diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index b5e0fb84b5..f192e70fce 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -285,7 +285,7 @@ def __init__(self, self.gpu_blocks = [wrapper.get_tensor() for wrapper in gpu_blocks] # Get pointers first self.gpu_blocks_ptrs = self._get_layer_ptrs(self.gpu_blocks) - self.gpu_layer_ptrs = self.gpu_blocks_ptrs + self.gpu_tensor_ptrs = self.gpu_blocks_ptrs self.cpu_tensor = cpu_blocks @@ -299,13 +299,20 @@ def __init__(self, self.chunk_size_in_bytes = gpu_kv_layout_per_layer.get_chunk_size() * self.dtype.itemsize self.gpu_kv_stride_in_bytes = gpu_kv_layout_per_layer.get_kv_stride() * self.dtype.itemsize self.gpu_block_stride_in_bytes = gpu_kv_layout_per_layer.get_block_stride() * self.dtype.itemsize + self.gpu_layer_stride_in_bytes = gpu_kv_layout_per_layer.get_layer_stride() * self.dtype.itemsize self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize self.cpu_block_stride_in_bytes = cpu_kv_layout.get_block_stride() * self.dtype.itemsize - if not gpu_kv_layout.type == KVCacheLayoutType.LAYERWISE: - raise ValueError("Only layerwise layout is supported for GPU") + if len(self.gpu_blocks) == 1: + self.gpu_block_type_ = 1 + elif len(self.gpu_blocks) == self.num_layers: + self.gpu_block_type_ = 0 + elif len(self.gpu_blocks) == self.num_layers * 2: + self.gpu_block_type_ = 2 + else: + raise ValueError(f"Invalid GPU block type: {len(self.gpu_blocks)}") # set GPU device if gpu_device_id != -1: torch.cuda.set_device(gpu_device_id) @@ -315,6 +322,18 @@ def __init__(self, self.use_ce_transfer_h2d = use_ce_transfer_h2d self.use_ce_transfer_d2h = use_ce_transfer_d2h + print(f"GPU block type: {self.gpu_block_type_}") + print(f"GPU blocks pointers: {self.gpu_blocks_ptrs}") + print(f"GPU tensor pointers: {self.gpu_tensor_ptrs}") + print(f"chunk size: {self.chunk_size_in_bytes}") + print(f"gpu kv stride: {self.gpu_kv_stride_in_bytes}") + print(f"gpu block stride: {self.gpu_block_stride_in_bytes}") + print(f"gpu layer stride: {self.gpu_layer_stride_in_bytes}") + print(f"cpu layer stride: {self.cpu_layer_stride_in_bytes}") + print(f"cpu kv stride: {self.cpu_kv_stride_in_bytes}") + print(f"cpu block stride: {self.cpu_block_stride_in_bytes}") + print(f"num layers: {self.num_layers}") + def _transfer_impl( self, src_block_ids: torch.Tensor, @@ -346,15 +365,14 @@ def _transfer_impl( if len(gpu_block_id_list) == 0: return - layer_id_list = torch.arange(layer_id, layer_id + layer_granularity, dtype=torch.int32) - - gpu_layer_ptrs = self.gpu_layer_ptrs[layer_id_list].contiguous().pin_memory() + gpu_tensor_ptrs = self.gpu_blocks_ptrs.contiguous().pin_memory() transfer_kv_blocks( gpu_block_id_list, - gpu_layer_ptrs, + gpu_tensor_ptrs, self.gpu_kv_stride_in_bytes, self.gpu_block_stride_in_bytes, + self.gpu_layer_stride_in_bytes, cpu_block_id_list, self.cpu_tensor, self.cpu_kv_stride_in_bytes, @@ -362,10 +380,12 @@ def _transfer_impl( self.cpu_block_stride_in_bytes, self.chunk_size_in_bytes, layer_id, + layer_granularity, transfer_sms, transfer_type == TransferType.H2D, use_ce_transfer, self.is_mla, + self.gpu_block_type_, ) def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: @@ -427,7 +447,7 @@ def __init__(self, blocks_in_one_gpu.append(handle.get_tensor()) imported_gpu_blocks.append(blocks_in_one_gpu) self.gpu_blocks = imported_gpu_blocks - self.dtype = dtype + self.dtype = dtype # note this should be quantized data type self.is_mla = gpu_kv_layouts[0].is_mla self.num_gpus = len(self.gpu_blocks) @@ -437,23 +457,22 @@ def __init__(self, cudaHostRegister(cpu_blocks) self.num_layers = gpu_kv_layouts[0].num_layer - gpu_kv_layouts_per_layer = [gpu_kv_layout.div_layer(self.num_layers) for gpu_kv_layout in gpu_kv_layouts] - - self.gpu_chunk_sizes_in_bytes = [gpu_kv_layout_per_layer.get_chunk_size() * self.dtype.itemsize \ - for gpu_kv_layout_per_layer in gpu_kv_layouts_per_layer] - self.gpu_kv_strides_in_bytes = [gpu_kv_layout_per_layer.get_kv_stride() * self.dtype.itemsize \ - for gpu_kv_layout_per_layer in gpu_kv_layouts_per_layer] - self.gpu_block_strides_in_bytes = [gpu_kv_layout_per_layer.get_block_stride() * self.dtype.itemsize \ - for gpu_kv_layout_per_layer in gpu_kv_layouts_per_layer] + + # here the chunk size doesn't include the layer info + self.gpu_chunk_sizes_in_bytes = [gpu_kv_layout.get_chunk_size() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + self.gpu_kv_strides_in_bytes = [gpu_kv_layout.get_kv_stride() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + self.gpu_block_strides_in_bytes = [gpu_kv_layout.get_block_stride() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + self.gpu_layer_strides_in_bytes = [gpu_kv_layout.get_layer_stride() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] self.cpu_chunk_size_in_bytes = cpu_kv_layout.get_chunk_size() * self.dtype.itemsize self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize self.cpu_block_stride_in_bytes = cpu_kv_layout.get_block_stride() * self.dtype.itemsize - if not gpu_kv_layouts[0].type == KVCacheLayoutType.LAYERWISE: - raise ValueError("Only layerwise layout is supported for GPU") - self.transfer_sms_h2d = transfer_sms_h2d self.transfer_sms_d2h = transfer_sms_d2h self.use_ce_transfer_h2d = use_ce_transfer_h2d @@ -462,10 +481,10 @@ def __init__(self, gpu_kv_strides_tensor = torch.tensor(self.gpu_kv_strides_in_bytes, dtype=torch.int64) gpu_block_strides_tensor = torch.tensor(self.gpu_block_strides_in_bytes, dtype=torch.int64) gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) - + gpu_layer_strides_tensor = torch.tensor(self.gpu_layer_strides_in_bytes, dtype=torch.int64) self.tp_transfer_thread_group = TPTransferThreadGroup(self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id, - gpu_kv_strides_tensor, gpu_block_strides_tensor, - gpu_chunk_sizes_tensor) + self.num_layers, gpu_kv_strides_tensor, + gpu_block_strides_tensor, gpu_layer_strides_tensor, gpu_chunk_sizes_tensor) def _transfer_impl(self, diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index 2e926ce95a..473cf0741c 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -24,20 +24,39 @@ create_gpu_kv_layout, GPUKVCacheVerifier ) -def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_config, num_gpu_blocks, child_conn): +def run_tp_client(dp_client_id, + tp_rank, + server_recv_port, + model_config, + cache_config, + num_gpu_blocks, + child_conn, + gpu_layout_type): """Run tp_client process""" try: device_id = tp_rank + dp_client_id * model_config.tp_size tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) - gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks) + gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks, gpu_layout_type) # Create GPU blocks for this tp_rank in the tp_client process gpu_blocks_for_tp = [] - for _ in range(model_config.num_layers): + if gpu_layout_type == 0: + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + elif gpu_layout_type == 1: gpu_blocks_for_tp.append( - torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + torch.empty(size=tuple(gpu_kv_layout.kv_shape[:]), dtype=model_config.dtype).cuda(device_id) ) + elif gpu_layout_type == 2: + for _ in range(model_config.num_layers * 2): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[2:]), dtype=model_config.dtype).cuda(device_id) + ) + else: + raise ValueError(f"Invalid GPU layout type: {gpu_layout_type}") tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) # Send GPU blocks back to main process via pipe if connection provided @@ -93,7 +112,12 @@ def shutdown_tp_client(tp_client_processes): KVCacheLayoutType.LAYERWISE, KVCacheLayoutType.BLOCKWISE, ]) -def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type): +@pytest.mark.parametrize("gpu_layout_type", [ + 0, + 1, + 2, +]) +def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type, gpu_layout_type): tp_size = model_config.tp_size dp_size = model_config.dp_size @@ -121,6 +145,9 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) # Skip tests based on GPU availability and configuration skip_if_insufficient_gpus(tp_size * dp_size) + if enable_gds and os.environ.get("FLEXKV_GDS_TEST", "0") == "0": + pytest.skip("skip because GDS test is not enabled") + if enable_remote: pytest.skip("skip because enable_remote is not supported") @@ -145,7 +172,7 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) tp_client_process = mp_ctx.Process( target=run_tp_client, - args=(0, tp_rank, gpu_register_port, model_config, cache_config, num_gpu_blocks + tp_rank, child_conn), + args=(0, tp_rank, gpu_register_port, model_config, cache_config, num_gpu_blocks + tp_rank, child_conn, gpu_layout_type), daemon=True ) tp_client_processes.append(tp_client_process) @@ -172,14 +199,15 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) print(f"[Main Process] Creating GPUKVCacheVerifier with GPU blocks from {len(all_gpu_blocks)} TP clients") # Get gpu_kv_layout from cache_config for GPUKVCacheVerifier - gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks) + gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks, gpu_layout_type) gpu_kv_verifier = GPUKVCacheVerifier( shared_gpu_blocks=all_gpu_blocks, gpu_kv_layout=gpu_kv_layout, tp_size=model_config.tp_size, tokens_per_block=cache_config.tokens_per_block, - dtype=model_config.dtype + dtype=model_config.dtype, + gpu_layout_type=gpu_layout_type ) print("[Main Process] GPUKVCacheVerifier created successfully") else: diff --git a/tests/test_utils.py b/tests/test_utils.py index 8a5a9241f2..86a64ef505 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -110,7 +110,7 @@ def block_ids_2_slot_mapping(block_ids, tokens_per_block, actual_length=-1): actual_length = len(block_ids) * tokens_per_block return slot_mapping[:actual_length] -def create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks): +def create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks, gpu_layout_type = 0): """Create GPU KV layout""" num_layers = model_config.num_layers num_kv_heads = model_config.num_kv_heads @@ -119,8 +119,14 @@ def create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks): tp_size = model_config.tp_size tokens_per_block = cache_config.tokens_per_block + if gpu_layout_type == 0 or gpu_layout_type == 2: + layout_type = KVCacheLayoutType.LAYERWISE + elif gpu_layout_type == 1: + layout_type = KVCacheLayoutType.BLOCKWISE + else: + raise ValueError(f"Invalid GPU layout type: {gpu_layout_type}") tpgroup_gpu_kv_layout = KVCacheLayout( - type=KVCacheLayoutType.LAYERWISE, + type=layout_type, num_layer=num_layers, num_block=num_gpu_blocks, tokens_per_block=tokens_per_block, @@ -456,9 +462,11 @@ def __init__(self, gpu_kv_layout: KVCacheLayout, tp_size: int, tokens_per_block: int, - dtype: torch.dtype)->None: + dtype: torch.dtype, + gpu_layout_type: int)->None: self.gpu_kv_layout = gpu_kv_layout self.num_layers = gpu_kv_layout.num_layer + self.gpu_layout_type = gpu_layout_type # we have to map the exported gpu blocks into the virtual space of current process if isinstance(shared_gpu_blocks[0], torch.Tensor): self.gpu_blocks = shared_gpu_blocks @@ -509,12 +517,12 @@ def fill_gpu_blocks(self, token_ids, block_ids): kv_num = 2 if not self.is_mla else 1 for kv_id in range(kv_num): for tp_id in range(self.tp_size): - if isinstance(self.gpu_blocks[0], list): - # multiple gpu:gpu_blocks[tp_id][layer_id] + if self.gpu_layout_type == 0: gpu_tensor = self.gpu_blocks[tp_id][layer_id] - else: - # single gpu:gpu_blocks[layer_id] - gpu_tensor = self.gpu_blocks[layer_id] + elif self.gpu_layout_type == 1: + gpu_tensor = self.gpu_blocks[tp_id][0] + elif self.gpu_layout_type == 2: + gpu_tensor = self.gpu_blocks[tp_id][layer_id + self.num_layers * kv_id] for head_id in range(self.gpu_kv_layout.num_head): actual_head_id = tp_id * self.gpu_kv_layout.num_head + head_id if not self.is_mla else head_id @@ -527,7 +535,17 @@ def fill_gpu_blocks(self, token_ids, block_ids): token_ids[start_token_idx:end_token_idx], actual_head_id) # GPU tensor dim:[kv_dim, num_block, tokens_per_block, num_head, head_size] - gpu_tensor[kv_id, block_id, :, head_id, :] = hash_value + if self.gpu_layout_type == 0: + # gpu_layout_type 0: [num_layer][kv_dim, num_block, tokens_per_block, num_head, head_size] + gpu_tensor[kv_id, block_id, :, head_id, :] = hash_value + elif self.gpu_layout_type == 1: + # gpu_layout_type 1: [tp_id][0][num_block, num_layer, kv_dim, tokens_per_block, num_head, head_size] + # Need to get the first (and only) tensor from the list + gpu_tensor[block_id, layer_id, kv_id, :, head_id, :] = hash_value + elif self.gpu_layout_type == 2: + gpu_tensor[block_id, :, head_id, :] = hash_value + else: + raise ValueError(f"Invalid GPU layout type: {self.gpu_layout_type}") def verify_kv_blocks(self, token_ids, block_ids)->bool: assert len(token_ids) == len(block_ids) * self.tokens_per_block @@ -544,10 +562,13 @@ def verify_kv_blocks(self, token_ids, block_ids)->bool: kv_num = 2 if not self.is_mla else 1 for kv_id in range(kv_num): for tp_id in range(self.tp_size): - if isinstance(self.gpu_blocks[0], list): + if self.gpu_layout_type == 0: + #if isinstance(self.gpu_blocks[0], list): gpu_tensor = self.gpu_blocks[tp_id][layer_id] - else: - gpu_tensor = self.gpu_blocks[layer_id] + elif self.gpu_layout_type == 1: + gpu_tensor = self.gpu_blocks[tp_id][0] + elif self.gpu_layout_type == 2: + gpu_tensor = self.gpu_blocks[tp_id][layer_id + self.num_layers * kv_id] for head_id in range(self.gpu_kv_layout.num_head): actual_head_id = tp_id * self.gpu_kv_layout.num_head + head_id if not self.is_mla else head_id @@ -557,8 +578,17 @@ def verify_kv_blocks(self, token_ids, block_ids)->bool: expected_hash_value = self.hash_all_values(layer_id, kv_id, token_ids[start_token_idx:end_token_idx], actual_head_id) - - actual_values = gpu_tensor[kv_id, block_id, :, head_id, :] + if self.gpu_layout_type == 0: + # gpu_layout_type 0: [num_layer][kv_dim, num_block, tokens_per_block, num_head, head_size] + actual_values = gpu_tensor[kv_id, block_id, :, head_id, :] + elif self.gpu_layout_type == 1: + # gpu_layout_type 1: [tp_id][0][num_block, num_layer, kv_dim, tokens_per_block, num_head, head_size] + # Need to get the first (and only) tensor from the list + actual_values = gpu_tensor[block_id, layer_id, kv_id, :, head_id, :] + elif self.gpu_layout_type == 2: + actual_values = gpu_tensor[block_id, :, head_id, :] + else: + raise ValueError(f"Invalid GPU layout type: {self.gpu_layout_type}") if not torch.allclose(actual_values, torch.full_like(actual_values, expected_hash_value),