
Commit 20da770

[HIP] Fix hipMemcpy failing to overlap with kernel execution; AMD GPU performance improves by more than 10% after this change (#33982)

1 parent 758dd7b
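
Note: the change below works around a HIP runtime behavior. On CUDA, a host-to-device copy of pageable host memory smaller than 64 KB is made asynchronous automatically; on ROCm/HIP, only copies from pinned (page-locked) memory can overlap with kernel execution. A minimal standalone sketch of the difference, assuming a ROCm/HIP toolchain (buffer names are illustrative, not from the commit):

// Sketch: pageable vs. pinned host buffers on HIP (assumes ROCm headers;
// names are illustrative). Only the pinned copy can overlap with other
// work already running on the device.
#include <hip/hip_runtime.h>
#include <cstdlib>

int main() {
  const size_t bytes = 1 << 20;
  hipStream_t stream;
  hipStreamCreate(&stream);

  float* dev = nullptr;
  hipMalloc(reinterpret_cast<void**>(&dev), bytes);

  // Pageable source: hipMemcpyAsync degrades to a blocking copy on HIP.
  float* pageable = static_cast<float*>(malloc(bytes));
  hipMemcpyAsync(dev, pageable, bytes, hipMemcpyHostToDevice, stream);

  // Pinned source: the copy is enqueued and can overlap with kernels.
  float* pinned = nullptr;
  hipHostMalloc(reinterpret_cast<void**>(&pinned), bytes,
                hipHostMallocDefault);
  hipMemcpyAsync(dev, pinned, bytes, hipMemcpyHostToDevice, stream);

  hipStreamSynchronize(stream);
  hipHostFree(pinned);
  free(pageable);
  hipFree(dev);
  hipStreamDestroy(stream);
  return 0;
}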


paddle/fluid/operators/math/concat_and_split.cu

Lines changed: 83 additions & 22 deletions
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include <algorithm>
 #include <vector>
+#include "gflags/gflags.h"
 #include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
@@ -242,8 +243,28 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     int in_col = input[0].numel() / in_row;
     int out_row = in_row, out_col = 0;
 
-    std::vector<const T*> inputs_data(in_num);
-    std::vector<int> inputs_col(in_num + 1);
+    int inputs_col_num = in_num + 1;
+    std::vector<const T*> inputs_data_vec(in_num);
+    std::vector<int> inputs_col_vec(inputs_col_num);
+    const T** inputs_data = inputs_data_vec.data();
+    int* inputs_col = inputs_col_vec.data();
+
+    // There are some differences between the HIP runtime and the NV runtime.
+    // On NV, a host-to-device transfer of pageable memory smaller than
+    // 64 KB is automatically asynchronous; on HIP, only pinned memory
+    // can be copied asynchronously. See
+    // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
+    // section 3.2.6.1, Concurrent Execution between Host and Device:
+    // "Memory copies from host to device of a memory block of 64 KB or less"
+#ifdef PADDLE_WITH_HIP
+    memory::AllocationPtr data_alloc, col_alloc;
+    data_alloc =
+        memory::Alloc(platform::CUDAPinnedPlace(), in_num * sizeof(T*));
+    inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
+    col_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
+                              inputs_col_num * sizeof(int));
+    inputs_col = reinterpret_cast<int*>(col_alloc->ptr());
+#endif
 
     inputs_col[0] = 0;
     bool has_same_shape = true;
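
Note on the hunk above: the vectors remain the backing store on the CUDA path, while under PADDLE_WITH_HIP the raw inputs_data / inputs_col pointers are redirected to pinned blocks, so the code that later fills and copies them is identical on both platforms. A condensed sketch of this pointer-swap idiom, with hypothetical names (USE_HIP_PINNED stands in for PADDLE_WITH_HIP):

// Condensed sketch of the staging-pointer swap (hypothetical names; the
// USE_HIP_PINNED macro stands in for PADDLE_WITH_HIP).
#include <hip/hip_runtime.h>
#include <vector>

void stage_and_copy(hipStream_t stream, int* dev_dst, int n) {
  std::vector<int> host_vec(n);  // pageable backing store (CUDA path)
  int* host = host_vec.data();
#ifdef USE_HIP_PINNED
  // HIP path: redirect the pointer to pinned memory so the async copy
  // below does not silently serialize.
  void* pinned = nullptr;
  hipHostMalloc(&pinned, n * sizeof(int), hipHostMallocDefault);
  host = static_cast<int*>(pinned);
#endif
  for (int i = 0; i < n; ++i) host[i] = i;  // fill whichever buffer is active
  hipMemcpyAsync(dev_dst, host, n * sizeof(int), hipMemcpyHostToDevice,
                 stream);
  // On the HIP path the pinned block must outlive the enqueued copy; the
  // commit defers its release with a stream callback (see the later hunk).
}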
@@ -264,12 +285,11 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     memory::allocation::AllocationPtr tmp_dev_ins_data;
     const T** dev_ins_data = nullptr;
     if (!has_same_shape || in_num < 2 || in_num > 4) {
-      tmp_dev_ins_data =
-          memory::Alloc(context, inputs_data.size() * sizeof(T*));
+      tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_ins_data->ptr(), platform::CPUPlace(),
-                   static_cast<void*>(inputs_data.data()),
-                   inputs_data.size() * sizeof(T*), context.stream());
+                   static_cast<void*>(inputs_data), in_num * sizeof(T*),
+                   context.stream());
       dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
     }
 
@@ -292,17 +312,29 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
       }
     } else {
       auto tmp_dev_ins_col_data =
-          memory::Alloc(context, inputs_col.size() * sizeof(int));
+          memory::Alloc(context, inputs_col_num * sizeof(int));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                   static_cast<void*>(inputs_col.data()),
-                   inputs_col.size() * sizeof(int), context.stream());
+                   static_cast<void*>(inputs_col), inputs_col_num * sizeof(int),
+                   context.stream());
       int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());
 
       ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
-          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
+          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col_num),
           out_row, out_col, output->data<T>());
     }
+#ifdef PADDLE_WITH_HIP
+    // Release the pinned memory only after the kernel launched on this
+    // stream has executed, so its contents are not overwritten while still
+    // in use (pinned memory is re-allocated on the next call).
+    auto* data_alloc_released = data_alloc.release();
+    auto* col_alloc_released = col_alloc.release();
+    context.AddStreamCallback([data_alloc_released, col_alloc_released] {
+      memory::allocation::AllocationDeleter deleter;
+      deleter(data_alloc_released);
+      deleter(col_alloc_released);
+    });
+#endif
   }
 };
 
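
Note on the #ifdef PADDLE_WITH_HIP block above: the pinned buffers are released from their AllocationPtr owners and handed to a stream callback, so they are freed only after every copy and kernel already enqueued on the stream has finished. Paddle's AddStreamCallback runs the functor on a host thread, where calling the allocator's deleter is legal. A minimal sketch of the same deferred-free idiom against the raw HIP API (illustrative; HIP, like CUDA, forbids runtime calls inside the callback itself, so the callback here only retires the pointer):

// Deferred free of a pinned block via hipStreamAddCallback (sketch).
#include <hip/hip_runtime.h>
#include <atomic>

std::atomic<void*> g_retired{nullptr};

// Fires after all work previously enqueued on the stream completes.
// Runtime calls are not allowed inside the callback, so only publish the
// pointer; another host thread performs the actual hipHostFree.
static void RetirePinned(hipStream_t, hipError_t, void* user_data) {
  g_retired.store(user_data, std::memory_order_release);
}

void EnqueueDeferredFree(hipStream_t stream, void* pinned_block) {
  // Until the callback fires, the block must not be freed or reused:
  // the device may still be reading it through the enqueued copy.
  hipStreamAddCallback(stream, RetirePinned, pinned_block, 0);
}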

@@ -313,6 +345,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
 template <typename T>
 class SplitFunctor<platform::CUDADeviceContext, T> {
  public:
+  SplitFunctor();
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
                   const std::vector<const framework::Tensor*>& ref_inputs,
@@ -329,8 +362,27 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     int64_t in_col = 0, in_row = out_row;
     bool has_same_shape = true;
 
-    std::vector<T*> outputs_data(o_num);
-    std::vector<int64_t> outputs_cols(o_num + 1);
+    int outputs_cols_num = o_num + 1;
+    std::vector<T*> outputs_data_vec(o_num);
+    std::vector<int64_t> outputs_cols_vec(outputs_cols_num);
+    T** outputs_data = outputs_data_vec.data();
+    int64_t* outputs_cols = outputs_cols_vec.data();
+
+    // There are some differences between the HIP runtime and the NV runtime.
+    // On NV, a host-to-device transfer of pageable memory smaller than
+    // 64 KB is automatically asynchronous; on HIP, only pinned memory
+    // can be copied asynchronously. See
+    // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
+    // section 3.2.6.1, Concurrent Execution between Host and Device:
+    // "Memory copies from host to device of a memory block of 64 KB or less"
+#ifdef PADDLE_WITH_HIP
+    memory::AllocationPtr data_alloc, cols_alloc;
+    data_alloc = memory::Alloc(platform::CUDAPinnedPlace(), o_num * sizeof(T*));
+    outputs_data = reinterpret_cast<T**>(data_alloc->ptr());
+    cols_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
+                               (outputs_cols_num) * sizeof(int64_t));
+    outputs_cols = reinterpret_cast<int64_t*>(cols_alloc->ptr());
+#endif
 
     outputs_cols[0] = 0;
     for (int i = 0; i < o_num; ++i) {
@@ -354,12 +406,11 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     memory::allocation::AllocationPtr tmp_dev_outs_data;
     T** dev_out_gpu_data = nullptr;
     if (!has_same_shape || o_num < 2 || o_num > 4) {
-      tmp_dev_outs_data =
-          memory::Alloc(context, outputs_data.size() * sizeof(T*));
+      tmp_dev_outs_data = memory::Alloc(context, o_num * sizeof(T*));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_outs_data->ptr(), platform::CPUPlace(),
-                   reinterpret_cast<void*>(outputs_data.data()),
-                   outputs_data.size() * sizeof(T*), context.stream());
+                   reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
+                   context.stream());
       dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
     }
 
@@ -382,20 +433,30 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
       }
     } else {
       auto tmp_dev_ins_col_data =
-          memory::Alloc(context,
-
-                        outputs_cols.size() * sizeof(int64_t));
+          memory::Alloc(context, outputs_cols_num * sizeof(int64_t));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                   reinterpret_cast<void*>(outputs_cols.data()),
-                   outputs_cols.size() * sizeof(int64_t), context.stream());
+                   reinterpret_cast<void*>(outputs_cols),
+                   outputs_cols_num * sizeof(int64_t), context.stream());
       int64_t* dev_outs_col_data =
           reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
 
       SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
           input.data<T>(), in_row, in_col, dev_outs_col_data,
-          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
+          static_cast<int>(outputs_cols_num), dev_out_gpu_data);
     }
+#ifdef PADDLE_WITH_HIP
+    // Release the pinned memory only after the kernel launched on this
+    // stream has executed, so its contents are not overwritten while still
+    // in use (pinned memory is re-allocated on the next call).
+    auto* data_alloc_released = data_alloc.release();
+    auto* cols_alloc_released = cols_alloc.release();
+    context.AddStreamCallback([data_alloc_released, cols_alloc_released] {
+      memory::allocation::AllocationDeleter deleter;
+      deleter(data_alloc_released);
+      deleter(cols_alloc_released);
+    });
+#endif
   }
 };
 