diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index c03ad74c40..c87f720238 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -45,7 +45,7 @@ void transfer_kv_blocks_binding( cudaStream_t stream = at::cuda::getCurrentCUDAStream(); flexkv::transfer_kv_blocks( num_blocks, start_layer_id, num_layers, gpu_block_ids, gpu_layer_ptrs, - gpu_kv_stride_in_bytes, gpu_block_stride_in_bytes, cpu_block_ids, cpu_ptr, + gpu_kv_stride_in_bytes, gpu_block_stride_in_bytes, 0, cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms, is_host_to_device, use_ce_transfer, is_mla); diff --git a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index 06cb45c4fe..b108d4cd83 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ b/csrc/tp_transfer_thread_group.cpp @@ -143,15 +143,26 @@ void TPTransferThreadGroup::tp_group_transfer( void **gpu_layer_ptrs = static_cast<void **>(gpu_blocks_ + i * num_layers + layer_id); void *cpu_ptr = cpu_blocks_; - int64_t cpu_startoff_inside_chunks = - is_mla ? 0 : i * gpu_chunk_sizes_in_bytes_[i]; + int64_t cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i]; + if (is_mla && !is_host_to_device) { + cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_; + } else if (is_mla && is_host_to_device) { + cpu_startoff_inside_chunks = 0; + } + int64_t gpu_startoff_inside_chunks = + is_mla && !is_host_to_device ? i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_ : 0; + // we assume that the chunk size is the same for all gpus, + // even if they have different number of gpu_blocks + int64_t chunk_size = is_mla && !is_host_to_device ? 
+ gpu_chunk_sizes_in_bytes_[i] / num_gpus_ : gpu_chunk_sizes_in_bytes_[i]; flexkv::transfer_kv_blocks( num_blocks, layer_id, layer_granularity, gpu_block_ids, gpu_layer_ptrs, gpu_kv_strides_in_bytes_[i], gpu_block_strides_in_bytes_[i], + gpu_startoff_inside_chunks, cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, - cpu_startoff_inside_chunks, gpu_chunk_sizes_in_bytes_[i], streams_[i], + cpu_startoff_inside_chunks, chunk_size, streams_[i], transfer_sms, is_host_to_device, use_ce_transfer, is_mla ); diff --git a/csrc/transfer.cu b/csrc/transfer.cu index 73859984ee..938d227935 100644 --- a/csrc/transfer.cu +++ b/csrc/transfer.cu @@ -26,6 +26,7 @@ namespace flexkv { __global__ void transfer_kv_blocks_kernel( int num_blocks, int start_layer_id, int num_layers, int64_t *gpu_block_ids, int64_t **gpu_layer_ptrs, int64_t gpu_kv_stride, int64_t gpu_block_stride, + int64_t gpu_startoff_inside_chunks, int64_t *cpu_block_ids, int64_t *cpu_ptr, int64_t cpu_kv_stride, int64_t cpu_layer_stride, int64_t cpu_block_stride, int64_t cpu_startoff_inside_chunks, int64_t copy_size, bool is_mla, @@ -47,7 +48,8 @@ __global__ void transfer_kv_blocks_kernel( cpu_startoff_inside_chunks; int64_t *gpu_chunk_ptr = gpu_layer_ptrs[layer_idx] + kv_idx * gpu_kv_stride + - gpu_block_idx * gpu_block_stride; + gpu_block_idx * gpu_block_stride + + gpu_startoff_inside_chunks; int64_t *src_chunk_ptr = is_host_to_device ? cpu_chunk_ptr : gpu_chunk_ptr; int64_t *dst_chunk_ptr = is_host_to_device ? 
gpu_chunk_ptr : cpu_chunk_ptr; @@ -63,7 +65,8 @@ void transfer_kv_blocks( int num_blocks, int start_layer_id, int num_layers, int64_t *gpu_block_ids, void **gpu_layer_ptrs, int64_t gpu_kv_stride_in_bytes, - int64_t gpu_block_stride_in_bytes, int64_t *cpu_block_ids, void *cpu_ptr, + int64_t gpu_block_stride_in_bytes, int64_t gpu_startoff_inside_chunks, + int64_t *cpu_block_ids, void *cpu_ptr, int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t cpu_startoff_inside_chunks, int64_t chunk_size_in_bytes, cudaStream_t stream, int transfer_sms, @@ -90,6 +93,8 @@ void transfer_kv_blocks( int64_t cpu_layer_stride_int64 = cpu_layer_stride_in_bytes / sizeof(int64_t); int64_t cpu_startoff_inside_chunks_int64 = cpu_startoff_inside_chunks / sizeof(int64_t); + int64_t gpu_startoff_inside_chunks_int64 = + gpu_startoff_inside_chunks / sizeof(int64_t); int64_t chunk_size_in_int64 = chunk_size_in_bytes / sizeof(int64_t); dim3 blockDim(block_size); @@ -107,7 +112,8 @@ void transfer_kv_blocks( cpu_startoff_inside_chunks_int64; int64_t *gpu_chunk_ptr = gpu_layer_ptrs_int64[i] + j * gpu_kv_stride_int64 + - gpu_block_idx * gpu_block_stride_int64; + gpu_block_idx * gpu_block_stride_int64 + + gpu_startoff_inside_chunks_int64; if (is_host_to_device) { cudaMemcpyAsync(gpu_chunk_ptr, cpu_chunk_ptr, chunk_size_in_bytes, @@ -123,6 +129,7 @@ void transfer_kv_blocks( transfer_kv_blocks_kernel<<<gridDim, blockDim, 0, stream>>>( num_blocks, start_layer_id, num_layers, gpu_block_ids, gpu_layer_ptrs_int64, gpu_kv_stride_int64, gpu_block_stride_int64, + gpu_startoff_inside_chunks_int64, cpu_block_ids, cpu_ptr_int64, cpu_kv_stride_int64, cpu_layer_stride_int64, cpu_block_stride_int64, cpu_startoff_inside_chunks_int64, chunk_size_in_int64, is_mla, diff --git a/csrc/transfer.cuh b/csrc/transfer.cuh index ba9b171ca5..ad46fc5be1 100644 --- a/csrc/transfer.cuh +++ b/csrc/transfer.cuh @@ -23,7 +23,8 @@ namespace flexkv { void 
transfer_kv_blocks( int num_blocks, int start_layer_id, int num_layers, int64_t *gpu_block_ids, void **gpu_layer_ptrs, int64_t gpu_kv_stride_in_bytes, - int64_t gpu_block_stride_in_bytes, int64_t *cpu_block_ids, void *cpu_ptr, + int64_t gpu_block_stride_in_bytes, int64_t gpu_startoff_inside_chunks, + int64_t *cpu_block_ids, void *cpu_ptr, int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t cpu_startoff_inside_chunks, int64_t chunk_size_in_bytes, cudaStream_t stream, int transfer_sms,