Commit c1456af

Enigmatisms authored and Luckycheng222 committed
[PHI] Add uint8/int16 CUDA atomic mul/min/max and upgraded take/put_along_axis (input types) (PaddlePaddle#74693)
* [PHI] Aligned uint8 and int16 atomic funcs
* [PHI] Removed some of the GPU only constraints.
* [PHI] Fixed put_along_axis CPU end test error
1 parent f087268 commit c1456af

13 files changed, +329 -64 lines

paddle/phi/backends/gpu/gpu_primitives.h

Lines changed: 89 additions & 0 deletions
@@ -457,6 +457,60 @@ CUDA_ATOMIC_WRAPPER(Mul, float) {
   return __int_as_float(old);
 }
 
+__device__ __forceinline__ uint32_t __loadAligned(const uintptr_t base_addr,
+                                                  uint32_t mask,
+                                                  uint32_t shift) {
+  // get 4B aligned address
+  uint32_t aligned_value = *reinterpret_cast<const uint32_t *>(base_addr);
+  return (aligned_value & mask) >> shift;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, uint8_t) {
+  // get 4B aligned base address
+  uintptr_t base_addr = reinterpret_cast<uintptr_t>(address) & (~3);
+  uint32_t offset = reinterpret_cast<uintptr_t>(address) - base_addr;
+  uint32_t shift = offset * 8;
+  uint32_t mask = 0xFFU << shift;
+
+  uint32_t old32 = __loadAligned(base_addr, mask, shift), assumed32 = 0;
+
+  do {
+    assumed32 = old32;
+    uint8_t current = static_cast<uint8_t>((old32 & mask) >> shift);
+    uint8_t new_val = current * val;
+    uint32_t new32 =
+        (old32 & ~mask) | (static_cast<uint32_t>(new_val) << shift);
+
+    old32 =
+        atomicCAS(reinterpret_cast<uint32_t *>(base_addr), assumed32, new32);
+  } while (assumed32 != old32);
+
+  return static_cast<uint8_t>((old32 & mask) >> shift);
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, int16_t) {
+  // get 4B aligned base address
+  uintptr_t base_addr = reinterpret_cast<uintptr_t>(address) & (~3);
+  uint32_t offset = (reinterpret_cast<uintptr_t>(address) - base_addr) / 2;
+  uint32_t shift = offset * 16;
+  uint32_t mask = 0xFFFFU << shift;
+
+  uint32_t old32 = __loadAligned(base_addr, mask, shift), assumed32 = 0;
+
+  do {
+    assumed32 = old32;
+    int16_t current = static_cast<int16_t>((old32 & mask) >> shift);
+    int16_t new_val = current * val;
+    uint32_t new32 =
+        (old32 & ~mask) | (static_cast<uint32_t>(new_val) << shift);
+
+    old32 =
+        atomicCAS(reinterpret_cast<uint32_t *>(base_addr), assumed32, new32);
+  } while (assumed32 != old32);
+
+  return static_cast<int16_t>((old32 & mask) >> shift);
+}
+
 CUDA_ATOMIC_WRAPPER(Mul, double) {
   unsigned long long int *const address_as_ull =  // NOLINT
       reinterpret_cast<unsigned long long int *>(address);  // NOLINT
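
CUDA's atomicCAS operates on 32-bit (or wider) words, so the wrappers above emulate a byte- or half-word-sized atomic by rewriting only the addressed lane of the enclosing 4-byte word. A standalone sketch of the same index arithmetic (plain host C++ with a made-up address, not part of the commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Pretend a uint8_t element sits at byte offset 2 of its 4-byte word.
      uintptr_t address = 0x1002;
      uintptr_t base_addr = address & ~uintptr_t{3};  // 0x1000, 4B aligned
      uint32_t shift = (address - base_addr) * 8;     // 16 bits
      uint32_t mask = 0xFFU << shift;                 // 0x00FF0000
      std::printf("base=%#zx shift=%u mask=%#x\n",
                  static_cast<size_t>(base_addr), shift, mask);
      return 0;
    }

Only the bits selected by mask change between assumed32 and new32, so a concurrent update to one of the other three bytes of the same word just forces another trip around the CAS loop.
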
@@ -943,6 +997,41 @@ CUDA_ATOMIC_WRAPPER(Min, phi::dtype::bfloat16) {
   }
 }
 
+#define DEFINE_ATOMIC_MINMAX(Dtype, OpType, operator)                      \
+  __device__ __forceinline__ Dtype CudaAtomic##OpType(Dtype *address,      \
+                                                      const Dtype val) {   \
+    uintptr_t base_addr = reinterpret_cast<uintptr_t>(address) & (~3);     \
+    uint32_t offset_bytes = reinterpret_cast<uintptr_t>(address) - base_addr; \
+    uint32_t shift = 0, mask = 0;                                          \
+    if constexpr (sizeof(Dtype) == 1) {                                    \
+      shift = offset_bytes * 8;                                            \
+      mask = 0xFFU << shift;                                               \
+    } else {                                                               \
+      shift = (offset_bytes / 2) * 16;                                     \
+      mask = 0xFFFFU << shift;                                             \
+    }                                                                      \
+    Dtype current = 0;                                                     \
+    Dtype new_val = 0;                                                     \
+    uint32_t assumed32 = 0, old32 = __loadAligned(base_addr, mask, shift); \
+    do {                                                                   \
+      assumed32 = old32;                                                   \
+      current = static_cast<Dtype>((old32 & mask) >> shift);               \
+      new_val = operator(current, val);                                    \
+      uint32_t new32 =                                                     \
+          (old32 & ~mask) | (static_cast<uint32_t>(new_val) << shift);     \
+      old32 = atomicCAS(                                                   \
+          reinterpret_cast<uint32_t *>(base_addr), assumed32, new32);      \
+    } while (assumed32 != old32);                                          \
+    return current;                                                        \
+  }
+
+DEFINE_ATOMIC_MINMAX(int16_t, Min, min)
+DEFINE_ATOMIC_MINMAX(int16_t, Max, max)
+DEFINE_ATOMIC_MINMAX(uint8_t, Min, min)
+DEFINE_ATOMIC_MINMAX(uint8_t, Max, max)
+
+#undef DEFINE_ATOMIC_MINMAX
+
 #ifdef PADDLE_WITH_CUDA
 /*
  * One thread block deals with elementwise atomicAdd for vector of len.
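
With these wrappers in place, device code can use uint8_t and int16_t with CudaAtomicMul/Min/Max exactly as it already does with the wider types; each call rewrites only the addressed byte or half-word inside its 4-byte word. A minimal usage sketch (hypothetical kernel name and setup, not part of the commit):

    #include <cstdint>
    #include "paddle/phi/backends/gpu/gpu_primitives.h"

    // Folds n int16_t values into a single running minimum. The caller is
    // expected to initialize *out (e.g. to INT16_MAX) before launching.
    __global__ void global_min_i16(const int16_t *in, int16_t *out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        phi::CudaAtomicMin(out, in[i]);  // generated by DEFINE_ATOMIC_MINMAX
      }
    }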

paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc

Lines changed: 1 addition & 0 deletions
@@ -180,5 +180,6 @@ PD_REGISTER_KERNEL(put_along_axis_grad,
                    float,
                    double,
                    int,
+                   int16_t,
                    uint8_t,
                    int64_t) {}

paddle/phi/kernels/cpu/put_along_axis_kernel.cc

Lines changed: 1 addition & 0 deletions
@@ -103,5 +103,6 @@ PD_REGISTER_KERNEL(put_along_axis,
                    float,
                    double,
                    int,
+                   int16_t,
                    uint8_t,
                    int64_t) {}

paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc

Lines changed: 1 addition & 0 deletions
@@ -66,5 +66,6 @@ PD_REGISTER_KERNEL(take_along_axis_grad,
                    float,
                    double,
                    int,
+                   int16_t,
                    uint8_t,
                    int64_t) {}

paddle/phi/kernels/cpu/take_along_axis_kernel.cc

Lines changed: 1 addition & 0 deletions
@@ -65,5 +65,6 @@ PD_REGISTER_KERNEL(take_along_axis,
                    float,
                    double,
                    int,
+                   int16_t,
                    uint8_t,
                    int64_t) {}

paddle/phi/kernels/funcs/gather_scatter_functor.cu

Lines changed: 4 additions & 32 deletions
@@ -31,65 +31,37 @@ static TensorAssign tensor_assign;
 
 class ReduceAdd {
  public:
-  template <
-      typename tensor_t,
-      std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
+  template <typename tensor_t>
   __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
     phi::CudaAtomicAdd(self_data, *src_data);
   }
-  template <typename tensor_t,
-            std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
-  __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
-    *self_data += *src_data;
-  }
 };
 static ReduceAdd reduce_add;
 
 class ReduceMul {
  public:
-  template <
-      typename tensor_t,
-      std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
+  template <typename tensor_t>
   __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
     phi::CudaAtomicMul(self_data, *src_data);
   }
-  template <typename tensor_t,
-            std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
-  __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
-    *self_data *= *src_data;
-  }
 };
 static ReduceMul reduce_mul;
 
 class ReduceMax {
  public:
-  template <
-      typename tensor_t,
-      std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
+  template <typename tensor_t>
   __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
     phi::CudaAtomicMax(self_data, *src_data);
   }
-  template <typename tensor_t,
-            std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
-  __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
-    *self_data = *src_data > *self_data ? *src_data : *self_data;
-  }
 };
 static ReduceMax reduce_max;
 
 class ReduceMin {
  public:
-  template <
-      typename tensor_t,
-      std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
+  template <typename tensor_t>
   __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
     phi::CudaAtomicMin(self_data, *src_data);
   }
-  template <typename tensor_t,
-            std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
-  __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
-    *self_data = *src_data < *self_data ? *src_data : *self_data;
-  }
 };
 static ReduceMin reduce_min;
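
With CudaAtomicMul/Min/Max now defined for uint8_t and int16_t, the non-atomic uint8_t overloads above could be dropped, so every element type takes the same race-free atomic path. A minimal sketch of how such a functor would be exercised (hypothetical kernel in the same translation unit, not part of the commit):

    // Each thread multiplies self[index[i]] by src[i]. Several threads may
    // target the same destination byte; that is now safe for uint8_t because
    // ReduceMul forwards to the CAS-based phi::CudaAtomicMul.
    __global__ void scatter_mul_u8(uint8_t *self, uint8_t *src,
                                   const int *index, int n) {
      ReduceMul mul_op;  // stateless functor from the diff above
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        mul_op(self + index[i], src + i);
      }
    }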

paddle/phi/kernels/funcs/gather_scatter_functor.h

Lines changed: 15 additions & 12 deletions
@@ -29,7 +29,8 @@ namespace funcs {
   Instantiate_Template_Function_index_t(func, phi::dtype::float16)       \
       Instantiate_Template_Function_index_t(func,                        \
                                             phi::dtype::bfloat16)        \
-          Instantiate_Template_Function_index_t(func, unsigned char)
+          Instantiate_Template_Function_index_t(func, unsigned char)     \
+              Instantiate_Template_Function_index_t(func, int16_t)
 
 #define Instantiate_Template_Function_index_t(func, tensor_t) \
   template void func<tensor_t, int>(phi::DenseTensor input,   \
@@ -45,17 +46,19 @@ namespace funcs {
                                     bool include_self,        \
                                     const phi::DeviceContext& dev_ctx);
 
-#define Instantiate_Template_Function_With_Out(func)                        \
-  Instantiate_Template_Function_index_t_With_Out(func, int)                 \
-      Instantiate_Template_Function_index_t_With_Out(func, float)           \
-          Instantiate_Template_Function_index_t_With_Out(func, double)      \
-              Instantiate_Template_Function_index_t_With_Out(func, int64_t) \
-                  Instantiate_Template_Function_index_t_With_Out(           \
-                      func, phi::dtype::float16)                            \
-                      Instantiate_Template_Function_index_t_With_Out(       \
-                          func, phi::dtype::bfloat16)                       \
-                          Instantiate_Template_Function_index_t_With_Out(   \
-                              func, unsigned char)
+#define Instantiate_Template_Function_With_Out(func)                        \
+  Instantiate_Template_Function_index_t_With_Out(func, int)                 \
+      Instantiate_Template_Function_index_t_With_Out(func, float)           \
+          Instantiate_Template_Function_index_t_With_Out(func, double)      \
+              Instantiate_Template_Function_index_t_With_Out(func, int64_t) \
+                  Instantiate_Template_Function_index_t_With_Out(           \
+                      func, phi::dtype::float16)                            \
+                      Instantiate_Template_Function_index_t_With_Out(       \
+                          func, phi::dtype::bfloat16)                       \
+                          Instantiate_Template_Function_index_t_With_Out(   \
+                              func, unsigned char)                          \
+                              Instantiate_Template_Function_index_t_With_Out( \
+                                  func, int16_t)
 #define Instantiate_Template_Function_index_t_With_Out(func, tensor_t) \
   template void func<tensor_t, int>(phi::DenseTensor input,            \
                                     int dim,                           \

paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu

Lines changed: 2 additions & 0 deletions
@@ -179,5 +179,7 @@ PD_REGISTER_KERNEL(put_along_axis_grad,
                    double,
                    int64_t,
                    int,
+                   int16_t,
+                   uint8_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}

paddle/phi/kernels/gpu/put_along_axis_kernel.cu

Lines changed: 2 additions & 0 deletions
@@ -102,6 +102,8 @@ PD_REGISTER_KERNEL(put_along_axis,
                    float,
                    double,
                    int64_t,
+                   uint8_t,
+                   int16_t,
                    int,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}

paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu

Lines changed: 2 additions & 0 deletions
@@ -73,5 +73,7 @@ PD_REGISTER_KERNEL(take_along_axis_grad,
                    double,
                    int64_t,
                    int,
+                   int16_t,
+                   uint8_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
