|
15 | 15 | #include "helper.h" |
16 | 16 | #include "paddle/extension.h" |
17 | 17 |
|
18 | | -// #define SWAP_DEBUG |
// D2H: each thread block handles ALL layers for one swap block, so the
// destination region written by a thread block is one contiguous span of
// layer_num * block_stride elements.
//
// Launch layout: 1-D grid with at least n_blocks thread blocks; blockIdx.x
// selects the entry of gpu_block_ids. Threads of the block stride over the
// payload of each layer in turn. No shared memory is used.
//
// Parameters:
//   layer_ptrs    - device array of layer_num pointers, one per layer tensor.
//   cpu_buffer    - destination base, indexed as [block][layer][block_stride].
//                   It is dereferenced by device code, so it must be memory
//                   the device can write.
//   gpu_block_ids - device array of n_blocks source block indices.
//   n_blocks      - number of blocks to copy.
//   layer_num     - number of layers.
//   block_stride  - number of T elements per block per layer.
//
// The payload is moved as float4 when both source and destination are
// 16-byte aligned; any bytes that do not fit a whole float4 (or the entire
// payload when unaligned) are copied byte-wise so nothing is dropped.
template <typename T>
__global__ void swap_d2h_kernel(T** __restrict__ layer_ptrs,
                                T* __restrict__ cpu_buffer,
                                const int64_t* __restrict__ gpu_block_ids,
                                int n_blocks,
                                int layer_num,
                                int64_t block_stride) {
  const int block_idx = blockIdx.x;
  if (block_idx >= n_blocks) return;

  const int64_t gpu_block = gpu_block_ids[block_idx];
  const int64_t bytes_per_layer =
      block_stride * static_cast<int64_t>(sizeof(T));
  const int64_t vec_bytes = static_cast<int64_t>(sizeof(float4));

  T* dst_base =
      cpu_buffer + static_cast<int64_t>(block_idx) * layer_num * block_stride;

  for (int layer_idx = 0; layer_idx < layer_num; layer_idx++) {
    const T* src = layer_ptrs[layer_idx] + gpu_block * block_stride;
    T* dst = dst_base + static_cast<int64_t>(layer_idx) * block_stride;

    // The alignment test is uniform across the thread block, so every thread
    // takes the same path.
    const bool aligned = ((reinterpret_cast<uintptr_t>(src) |
                           reinterpret_cast<uintptr_t>(dst)) &
                          (sizeof(float4) - 1)) == 0;
    const int64_t num_vec = aligned ? bytes_per_layer / vec_bytes : 0;

    const float4* src4 = reinterpret_cast<const float4*>(src);
    float4* dst4 = reinterpret_cast<float4*>(dst);
    for (int64_t i = threadIdx.x; i < num_vec; i += blockDim.x) {
      dst4[i] = src4[i];
    }

    // Remainder: bytes past the last whole float4, or the whole payload when
    // the pointers are not 16-byte aligned.
    const char* src_bytes = reinterpret_cast<const char*>(src);
    char* dst_bytes = reinterpret_cast<char*>(dst);
    for (int64_t b = num_vec * vec_bytes + threadIdx.x; b < bytes_per_layer;
         b += blockDim.x) {
      dst_bytes[b] = src_bytes[b];
    }
  }
}
| 47 | + |
// H2D: scatter from a contiguous staging buffer into the per-layer tensors.
//
// Launch layout: 1-D grid in which blockIdx.x encodes a (block, layer) pair as
// block_idx * layer_num + layer_idx; thread blocks whose decoded block_idx is
// >= n_blocks exit immediately. Threads of the block stride over the payload
// of that single (block, layer) pair. No shared memory is used.
//
// Parameters:
//   layer_ptrs    - device array of layer_num pointers, one per layer tensor.
//   staging       - source base, indexed as [block][layer][block_stride].
//   gpu_block_ids - device array of n_blocks destination block indices.
//   n_blocks      - number of blocks to copy.
//   layer_num     - number of layers (must be non-zero; it is a divisor).
//   block_stride  - number of T elements per block per layer.
//
// The payload is moved as float4 when both source and destination are
// 16-byte aligned; any bytes that do not fit a whole float4 (or the entire
// payload when unaligned) are copied byte-wise so nothing is dropped.
template <typename T>
__global__ void scatter_blocks_kernel(T** __restrict__ layer_ptrs,
                                      const T* __restrict__ staging,
                                      const int64_t* __restrict__ gpu_block_ids,
                                      int n_blocks,
                                      int layer_num,
                                      int64_t block_stride) {
  const int pair_idx = blockIdx.x;
  const int block_idx = pair_idx / layer_num;
  const int layer_idx = pair_idx % layer_num;

  if (block_idx >= n_blocks) return;

  const int64_t gpu_block = gpu_block_ids[block_idx];
  const T* src =
      staging +
      (static_cast<int64_t>(block_idx) * layer_num + layer_idx) * block_stride;
  T* dst = layer_ptrs[layer_idx] + gpu_block * block_stride;

  const int64_t payload_bytes = block_stride * static_cast<int64_t>(sizeof(T));
  const int64_t vec_bytes = static_cast<int64_t>(sizeof(float4));

  // The alignment test is uniform across the thread block, so every thread
  // takes the same path.
  const bool aligned = ((reinterpret_cast<uintptr_t>(src) |
                         reinterpret_cast<uintptr_t>(dst)) &
                        (sizeof(float4) - 1)) == 0;
  const int64_t num_vec = aligned ? payload_bytes / vec_bytes : 0;

  const float4* src4 = reinterpret_cast<const float4*>(src);
  float4* dst4 = reinterpret_cast<float4*>(dst);
  for (int64_t i = threadIdx.x; i < num_vec; i += blockDim.x) {
    dst4[i] = src4[i];
  }

  // Remainder: bytes past the last whole float4, or the whole payload when
  // the pointers are not 16-byte aligned.
  const char* src_bytes = reinterpret_cast<const char*>(src);
  char* dst_bytes = reinterpret_cast<char*>(dst);
  for (int64_t b = num_vec * vec_bytes + threadIdx.x; b < payload_bytes;
       b += blockDim.x) {
    dst_bytes[b] = src_bytes[b];
  }
}
| 75 | + |
// File-scope scratch allocations that only ever grow and are reused across
// calls. Each pointer is paired with the byte capacity it was allocated with.
// NOTE(review): these are plain statics with no locking and no record of
// which device they were allocated on — confirm callers use a single device
// and a single thread per process.
static void* g_staging_buffer = nullptr;
static size_t g_staging_buffer_size = 0;
static void* g_device_block_ids = nullptr;
static size_t g_device_block_ids_size = 0;
static void* g_device_layer_ptrs = nullptr;
static size_t g_device_layer_ptrs_size = 0;

// Makes sure *buffer holds at least required_size bytes of cudaMalloc'ed
// memory, reallocating (without preserving contents) when it is too small.
//
// The pointer and capacity are reset before the new allocation is attempted,
// so a failed cudaMalloc can never leave a freed pointer paired with a
// non-zero capacity. `what` names the buffer in error messages.
// Raises a phi External error if cudaFree or cudaMalloc fails.
static void ensure_device_capacity(void** buffer,
                                   size_t* capacity,
                                   size_t required_size,
                                   const char* what) {
  if (*capacity >= required_size) return;

  if (*buffer != nullptr) {
    cudaError_t free_err = cudaFree(*buffer);
    *buffer = nullptr;
    *capacity = 0;
    PADDLE_ENFORCE_EQ(free_err,
                      cudaSuccess,
                      phi::errors::External("cudaFree %s failed: %s",
                                            what,
                                            cudaGetErrorString(free_err)));
  }

  void* fresh = nullptr;
  cudaError_t err = cudaMalloc(&fresh, required_size);
  PADDLE_ENFORCE_EQ(err,
                    cudaSuccess,
                    phi::errors::External("cudaMalloc %s failed: %s",
                                          what,
                                          cudaGetErrorString(err)));
  *buffer = fresh;
  *capacity = required_size;
}

// Grows g_staging_buffer to at least required_size bytes.
static void ensure_staging_buffer(size_t required_size) {
  ensure_device_capacity(
      &g_staging_buffer, &g_staging_buffer_size, required_size,
      "staging buffer");
}

// Grows g_device_block_ids to at least required_size bytes.
static void ensure_device_block_ids(size_t required_size) {
  ensure_device_capacity(&g_device_block_ids,
                         &g_device_block_ids_size,
                         required_size,
                         "device block_ids");
}

// Grows g_device_layer_ptrs to at least required_size bytes.
static void ensure_device_layer_ptrs(size_t required_size) {
  ensure_device_capacity(&g_device_layer_ptrs,
                         &g_device_layer_ptrs_size,
                         required_size,
                         "device layer_ptrs");
}
| 121 | + |
// Returns true when cpu_block_ids is a run of consecutive ascending integers
// (ids[i] == ids[0] + i for every i). Empty and single-element lists count as
// sequential.
static bool is_cpu_block_ids_sequential(
    const std::vector<int64_t>& cpu_block_ids) {
  const size_t count = cpu_block_ids.size();
  size_t pos = 1;
  // Advance while each entry matches its expected offset from the first one.
  while (pos < count &&
         cpu_block_ids[pos] == cpu_block_ids[0] + static_cast<int64_t>(pos)) {
    ++pos;
  }
  // Reaching the end (or having fewer than two entries) means no mismatch.
  return pos >= count;
}
19 | 131 |
|
// Copies cache blocks between per-layer GPU tensors and one host buffer, then
// blocks until the stream has finished.
//
// mode == 0 copies GPU -> CPU; any other value copies CPU -> GPU.
//
// Layouts, as implied by the indexing below (S = product of cache_shape[1:]):
//   GPU side : one tensor per layer, block b of a layer starts at element b*S.
//   host side: block b, layer l starts at element (b * layer_num + l) * S.
//
// Parameters:
//   cache_gpu_tensors - one GPU tensor per layer.
//   cache_cpu_pointer - address of the host buffer, passed as an integer.
//   cache_shape       - per-layer tensor shape; dimension 0 is ignored here.
//   gpu_block_ids     - block indices on the GPU side.
//   cpu_block_ids     - matching block indices on the host side; entry i pairs
//                       with gpu_block_ids[i].
//
// Two strategies are used:
//   * Fast path, when the host block ids are consecutive and the per-block
//     byte count is a multiple of sizeof(float4): a single kernel (D2H) or one
//     bulk copy into a device staging buffer followed by a scatter kernel
//     (H2D).
//   * Fallback: one cudaMemcpyAsync per (layer, block).
//
// Raises a phi error if the id lists are inconsistent or any CUDA call fails.
template <paddle::DataType D>
void SwapCacheImpLayout(const std::vector<paddle::Tensor>& cache_gpu_tensors,
                        const int64_t& cache_cpu_pointer,
                        const std::vector<int64_t>& cache_shape,
                        const std::vector<int64_t>& gpu_block_ids,
                        const std::vector<int64_t>& cpu_block_ids,
                        int mode) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const int64_t layer_number = cache_gpu_tensors.size();
  int64_t cache_block_stride = 1;
  for (size_t i = 1; i < cache_shape.size(); i++) {
    cache_block_stride *= cache_shape[i];
  }

  const int n_blocks = gpu_block_ids.size();
  // Nothing to move; also keeps cache_gpu_tensors[0] below in bounds.
  if (n_blocks == 0 || layer_number == 0) return;

  // Every GPU block id is paired with cpu_block_ids at the same index, so a
  // shorter host list would be read out of bounds.
  PADDLE_ENFORCE_EQ(
      cpu_block_ids.size() >= gpu_block_ids.size(),
      true,
      phi::errors::InvalidArgument(
          "cpu_block_ids must provide one entry per gpu_block_ids entry"));

  auto stream = cache_gpu_tensors[0].stream();
  const size_t block_bytes = cache_block_stride * sizeof(DataType_);
  const size_t total_bytes = (size_t)n_blocks * layer_number * block_bytes;

  // The kernels move data as float4, so the fast path is only taken when each
  // block is a whole number of float4 values.
  bool use_optimized = is_cpu_block_ids_sequential(cpu_block_ids) &&
                       (block_bytes % sizeof(float4) == 0);

  DataType_* cache_cpu_base = nullptr;
  if (use_optimized) {
    const int64_t cpu_start_block = cpu_block_ids[0];
    cache_cpu_base = reinterpret_cast<DataType_*>(cache_cpu_pointer) +
                     cpu_start_block * layer_number * cache_block_stride;
    // In D2H mode the kernel stores float4 values straight into the host
    // buffer, which needs a 16-byte aligned destination.
    if (mode == 0 &&
        reinterpret_cast<uintptr_t>(cache_cpu_base) % sizeof(float4) != 0) {
      use_optimized = false;
    }
  }

  if (use_optimized) {
    ensure_device_block_ids(n_blocks * sizeof(int64_t));
    ensure_device_layer_ptrs(layer_number * sizeof(DataType_*));

    cudaError_t status = cudaMemcpyAsync(g_device_block_ids,
                                         gpu_block_ids.data(),
                                         n_blocks * sizeof(int64_t),
                                         cudaMemcpyHostToDevice,
                                         stream);
    PADDLE_ENFORCE_EQ(
        status,
        cudaSuccess,
        phi::errors::External("cudaMemcpyAsync block_ids failed: %s",
                              cudaGetErrorString(status)));

    std::vector<DataType_*> h_layer_ptrs(layer_number);
    for (int64_t i = 0; i < layer_number; i++) {
      h_layer_ptrs[i] = reinterpret_cast<DataType_*>(
          const_cast<data_t*>(cache_gpu_tensors[i].data<data_t>()));
    }
    status = cudaMemcpyAsync(g_device_layer_ptrs,
                             h_layer_ptrs.data(),
                             layer_number * sizeof(DataType_*),
                             cudaMemcpyHostToDevice,
                             stream);
    PADDLE_ENFORCE_EQ(
        status,
        cudaSuccess,
        phi::errors::External("cudaMemcpyAsync layer_ptrs failed: %s",
                              cudaGetErrorString(status)));

    if (mode == 0) {
      // GPU->CPU: the kernel writes directly through cache_cpu_base.
      // NOTE(review): this assumes cache_cpu_pointer refers to host memory the
      // device can write (e.g. pinned/mapped) — not verified here.
      // One thread block per swap block; each handles all layers.
      swap_d2h_kernel<DataType_><<<n_blocks, 512, 0, stream>>>(
          reinterpret_cast<DataType_**>(g_device_layer_ptrs),
          cache_cpu_base,
          reinterpret_cast<int64_t*>(g_device_block_ids),
          n_blocks,
          layer_number,
          cache_block_stride);
      status = cudaGetLastError();
      PADDLE_ENFORCE_EQ(
          status,
          cudaSuccess,
          phi::errors::External("swap_d2h_kernel launch failed: %s",
                                cudaGetErrorString(status)));
    } else {
      // CPU->GPU: bulk copy into the staging buffer, then scatter per
      // (block, layer) pair.
      ensure_staging_buffer(total_bytes);

      status = cudaMemcpyAsync(g_staging_buffer,
                               cache_cpu_base,
                               total_bytes,
                               cudaMemcpyHostToDevice,
                               stream);
      PADDLE_ENFORCE_EQ(status,
                        cudaSuccess,
                        phi::errors::External("cudaMemcpyAsync H2D failed: %s",
                                              cudaGetErrorString(status)));

      // One thread block per (block, layer) pair.
      const int grid_size = n_blocks * layer_number;
      scatter_blocks_kernel<DataType_><<<grid_size, 256, 0, stream>>>(
          reinterpret_cast<DataType_**>(g_device_layer_ptrs),
          reinterpret_cast<const DataType_*>(g_staging_buffer),
          reinterpret_cast<int64_t*>(g_device_block_ids),
          n_blocks,
          layer_number,
          cache_block_stride);
      status = cudaGetLastError();
      PADDLE_ENFORCE_EQ(
          status,
          cudaSuccess,
          phi::errors::External("scatter_blocks_kernel launch failed: %s",
                                cudaGetErrorString(status)));
    }
  } else {
    // Fallback: one async copy per (layer, block).
    const cudaMemcpyKind copy_kind =
        (mode == 0) ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice;
    for (int64_t layer_idx = 0; layer_idx < layer_number; layer_idx++) {
      const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx];
      data_t* cache_gpu_ptr = const_cast<data_t*>(cache_gpu.data<data_t>());
      auto* cache_cpu_ptr = reinterpret_cast<data_t*>(cache_cpu_pointer);

      for (int block_idx = 0; block_idx < n_blocks; block_idx++) {
        auto cur_gpu_block_id = gpu_block_ids[block_idx];
        auto cur_cpu_block_id = cpu_block_ids[block_idx];
        auto* cache_gpu_ptr_now =
            cache_gpu_ptr + cur_gpu_block_id * cache_block_stride;
        auto* cache_cpu_ptr_now =
            cache_cpu_ptr +
            cur_cpu_block_id * cache_block_stride * layer_number +
            layer_idx * cache_block_stride;

        cudaError_t status = cudaMemcpyAsync(
            (copy_kind == cudaMemcpyDeviceToHost) ? cache_cpu_ptr_now
                                                  : cache_gpu_ptr_now,
            (copy_kind == cudaMemcpyDeviceToHost) ? cache_gpu_ptr_now
                                                  : cache_cpu_ptr_now,
            block_bytes,
            copy_kind,
            stream);
        PADDLE_ENFORCE_EQ(status,
                          cudaSuccess,
                          phi::errors::External("cudaMemcpyAsync failed: %s",
                                                cudaGetErrorString(status)));
      }
    }
  }

  // Block until every copy/kernel queued above has completed.
  cudaError_t sync_status = cudaStreamSynchronize(stream);
  PADDLE_ENFORCE_EQ(sync_status,
                    cudaSuccess,
                    phi::errors::External("cudaStreamSynchronize failed: %s",
                                          cudaGetErrorString(sync_status)));
}
92 | 257 |
|
93 | | -void SwapCacheLayout( |
94 | | - const std::vector<paddle::Tensor>& cache_gpu_tensors, // gpu |
95 | | - const int64_t& cache_cpu_ptrs, // cpu memory pointer |
96 | | - const std::vector<int64_t>& cache_shape, |
97 | | - const std::vector<int64_t>& gpu_block_ids, |
98 | | - const std::vector<int64_t>& cpu_block_ids, |
99 | | - int rank, |
100 | | - int mode) { |
101 | | - cudaSetDevice(rank); // used for distributed launch |
| 258 | +void SwapCacheLayout(const std::vector<paddle::Tensor>& cache_gpu_tensors, |
| 259 | + const int64_t& cache_cpu_ptrs, |
| 260 | + const std::vector<int64_t>& cache_shape, |
| 261 | + const std::vector<int64_t>& gpu_block_ids, |
| 262 | + const std::vector<int64_t>& cpu_block_ids, |
| 263 | + int rank, |
| 264 | + int mode) { |
| 265 | + cudaSetDevice(rank); |
102 | 266 | assert(cache_gpu_tensors.size() > 0); |
103 | 267 | switch (cache_gpu_tensors[0].dtype()) { |
104 | 268 | case paddle::DataType::BFLOAT16: |
|
0 commit comments