89 changes: 89 additions & 0 deletions paddle/phi/backends/gpu/gpu_primitives.h
@@ -457,6 +457,60 @@ CUDA_ATOMIC_WRAPPER(Mul, float) {
return __int_as_float(old);
}

__device__ __forceinline__ uint32_t __loadAligned(const uintptr_t base_addr,
uint32_t mask,
uint32_t shift) {
// read the 32-bit word at the 4-byte-aligned address and return the masked field
uint32_t aligned_value = *reinterpret_cast<const uint32_t *>(base_addr);
return (aligned_value & mask) >> shift;
}

CUDA_ATOMIC_WRAPPER(Mul, uint8_t) {
// get 4-byte aligned base address
uintptr_t base_addr = reinterpret_cast<uintptr_t>(address) & (~3);
uint32_t offset = reinterpret_cast<uintptr_t>(address) - base_addr;
uint32_t shift = offset * 8;
uint32_t mask = 0xFFU << shift;

uint32_t old32 = __loadAligned(base_addr, mask, shift), assumed32 = 0;

do {
assumed32 = old32;
uint8_t current = static_cast<uint8_t>((old32 & mask) >> shift);
uint8_t new_val = current * val;
uint32_t new32 =
(old32 & ~mask) | (static_cast<uint32_t>(new_val) << shift);

old32 =
atomicCAS(reinterpret_cast<uint32_t *>(base_addr), assumed32, new32);
} while (assumed32 != old32);

return static_cast<uint8_t>((old32 & mask) >> shift);
}

CUDA_ATOMIC_WRAPPER(Mul, int16_t) {
// get 4-byte aligned base address
uintptr_t base_addr = reinterpret_cast<uintptr_t>(address) & (~3);
uint32_t offset = (reinterpret_cast<uintptr_t>(address) - base_addr) / 2;
uint32_t shift = offset * 16;
uint32_t mask = 0xFFFFU << shift;

uint32_t old32 = __loadAligned(base_addr, mask, shift), assumed32 = 0;

do {
assumed32 = old32;
int16_t current = static_cast<int16_t>((old32 & mask) >> shift);
int16_t new_val = current * val;
uint32_t new32 =
(old32 & ~mask) | (static_cast<uint32_t>(new_val) << shift);

old32 =
atomicCAS(reinterpret_cast<uint32_t *>(base_addr), assumed32, new32);
} while (assumed32 != old32);

return static_cast<int16_t>((old32 & mask) >> shift);
}

CUDA_ATOMIC_WRAPPER(Mul, double) {
unsigned long long int *const address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int *>(address); // NOLINT
@@ -943,6 +997,41 @@ CUDA_ATOMIC_WRAPPER(Min, phi::dtype::bfloat16) {
}
}

#define DEFINE_ATOMIC_MINMAX(Dtype, OpType, operator) \
__device__ __forceinline__ Dtype CudaAtomic##OpType(Dtype *address, \
const Dtype val) { \
uintptr_t base_addr = reinterpret_cast<uintptr_t>(address) & (~3); \
uint32_t offset_bytes = reinterpret_cast<uintptr_t>(address) - base_addr; \
uint32_t shift = 0, mask = 0; \
if constexpr (sizeof(Dtype) == 1) { \
shift = offset_bytes * 8; \
mask = 0xFFU << shift; \
} else { \
shift = (offset_bytes / 2) * 16; \
mask = 0xFFFFU << shift; \
} \
Dtype current = 0; \
Dtype new_val = 0; \
uint32_t assumed32 = 0, old32 = __loadAligned(base_addr, mask, shift); \
do { \
assumed32 = old32; \
current = static_cast<Dtype>((old32 & mask) >> shift); \
new_val = operator(current, val); \
uint32_t new32 = \
(old32 & ~mask) | (static_cast<uint32_t>(new_val) << shift); \
old32 = atomicCAS( \
reinterpret_cast<uint32_t *>(base_addr), assumed32, new32); \
} while (assumed32 != old32); \
return current; \
}

DEFINE_ATOMIC_MINMAX(int16_t, Min, min)
DEFINE_ATOMIC_MINMAX(int16_t, Max, max)
DEFINE_ATOMIC_MINMAX(uint8_t, Min, min)
DEFINE_ATOMIC_MINMAX(uint8_t, Max, max)

#undef DEFINE_ATOMIC_MINMAX

#ifdef PADDLE_WITH_CUDA
/*
* One thead block deals with elementwise atomicAdd for vector of len.
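Note on the approach: CUDA has no native 8-bit or 16-bit atomics, so each of the new wrappers spins on a 32-bit `atomicCAS` over the 4-byte-aligned word containing the element and splices the updated sub-word back in. Below is a minimal standalone sketch of the same loop outside the `CUDA_ATOMIC_WRAPPER` machinery; `atomic_mul_u8` and `mul_kernel` are illustrative names, not part of this header, and the initial read here fetches the whole word rather than the pre-masked field.

```cuda
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Atomically multiply the byte at *addr by val using a 32-bit CAS on the
// enclosing 4-byte-aligned word; returns the previous value of the byte.
__device__ uint8_t atomic_mul_u8(uint8_t *addr, uint8_t val) {
  uintptr_t base = reinterpret_cast<uintptr_t>(addr) & ~static_cast<uintptr_t>(3);
  uint32_t shift = (reinterpret_cast<uintptr_t>(addr) - base) * 8;
  uint32_t mask = 0xFFu << shift;
  uint32_t *word = reinterpret_cast<uint32_t *>(base);

  uint32_t old32 = *word;
  uint32_t assumed32;
  do {
    assumed32 = old32;
    uint8_t cur = static_cast<uint8_t>((assumed32 & mask) >> shift);
    uint8_t upd = static_cast<uint8_t>(cur * val);
    uint32_t new32 = (assumed32 & ~mask) | (static_cast<uint32_t>(upd) << shift);
    old32 = atomicCAS(word, assumed32, new32);  // retry if another thread changed the word
  } while (assumed32 != old32);
  return static_cast<uint8_t>((old32 & mask) >> shift);
}

__global__ void mul_kernel(uint8_t *data) {
  // 16 threads, 4 per byte, all hammering the same 32-bit word.
  atomic_mul_u8(&data[threadIdx.x % 4], 3);
}

int main() {
  uint8_t h[4] = {1, 1, 1, 1};
  uint8_t *d = nullptr;
  cudaMalloc(&d, sizeof(h));
  cudaMemcpy(d, h, sizeof(h), cudaMemcpyHostToDevice);
  mul_kernel<<<1, 16>>>(d);
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  // Each byte is multiplied by 3 four times: expect 81 81 81 81.
  printf("%d %d %d %d\n", h[0], h[1], h[2], h[3]);
  cudaFree(d);
  return 0;
}
```

As in the wrappers above, the value returned is the one read before the update, matching the convention of the other `CudaAtomic*` helpers.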
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
@@ -180,5 +180,6 @@ PD_REGISTER_KERNEL(put_along_axis_grad,
float,
double,
int,
int16_t,
uint8_t,
int64_t) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/put_along_axis_kernel.cc
@@ -103,5 +103,6 @@ PD_REGISTER_KERNEL(put_along_axis,
float,
double,
int,
int16_t,
uint8_t,
int64_t) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc
@@ -66,5 +66,6 @@ PD_REGISTER_KERNEL(take_along_axis_grad,
float,
double,
int,
int16_t,
uint8_t,
int64_t) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/take_along_axis_kernel.cc
@@ -65,5 +65,6 @@ PD_REGISTER_KERNEL(take_along_axis,
float,
double,
int,
int16_t,
uint8_t,
int64_t) {}
36 changes: 4 additions & 32 deletions paddle/phi/kernels/funcs/gather_scatter_functor.cu
@@ -31,65 +31,37 @@ static TensorAssign tensor_assign;

class ReduceAdd {
public:
template <
typename tensor_t,
std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
template <typename tensor_t>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
phi::CudaAtomicAdd(self_data, *src_data);
}
template <typename tensor_t,
std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data += *src_data;
}
};
static ReduceAdd reduce_add;

class ReduceMul {
public:
template <
typename tensor_t,
std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
template <typename tensor_t>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
phi::CudaAtomicMul(self_data, *src_data);
}
template <typename tensor_t,
std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data *= *src_data;
}
};
static ReduceMul reduce_mul;

class ReduceMax {
public:
template <
typename tensor_t,
std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
template <typename tensor_t>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
phi::CudaAtomicMax(self_data, *src_data);
}
template <typename tensor_t,
std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data = *src_data > *self_data ? *src_data : *self_data;
}
};
static ReduceMax reduce_max;

class ReduceMin {
public:
template <
typename tensor_t,
std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
template <typename tensor_t>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
phi::CudaAtomicMin(self_data, *src_data);
}
template <typename tensor_t,
std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data = *src_data < *self_data ? *src_data : *self_data;
}
};
static ReduceMin reduce_min;

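With `CudaAtomicMul`, `CudaAtomicMax` and `CudaAtomicMin` now defined for `uint8_t` and `int16_t`, the separate non-atomic `uint8_t` branches (`*self_data *= *src_data;` and friends) are no longer needed; those could race when two source elements scatter into the same output slot, so every dtype now takes the atomic path. A minimal, self-contained sketch of how such a functor is consumed by a scatter-style kernel follows; it uses `int` and plain `atomicAdd` since the `phi::CudaAtomic*` wrappers are not pulled in here, and `ReduceAddSketch`/`scatter_reduce` are illustrative names.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for the functors above: one templated operator() that forwards to a
// device atomic, so every element type goes down the same code path.
struct ReduceAddSketch {
  template <typename T>
  __device__ void operator()(T *self, const T *src) const {
    atomicAdd(self, *src);  // phi::CudaAtomicAdd in the real functors
  }
};

// Scatter-style loop: several inputs may map to the same output slot, which is
// exactly the situation where a non-atomic read-modify-write would race.
template <typename T, typename ReduceOp>
__global__ void scatter_reduce(T *out, const T *src, const int *index, int n,
                               ReduceOp op) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) op(out + index[i], src + i);
}

int main() {
  const int n = 8;
  int h_src[n] = {1, 2, 3, 4, 5, 6, 7, 8};
  int h_idx[n] = {0, 0, 1, 1, 2, 2, 3, 3};  // two sources per output slot
  int h_out[4] = {0, 0, 0, 0};

  int *d_src, *d_idx, *d_out;
  cudaMalloc(&d_src, sizeof(h_src));
  cudaMalloc(&d_idx, sizeof(h_idx));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);
  cudaMemcpy(d_idx, h_idx, sizeof(h_idx), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, h_out, sizeof(h_out), cudaMemcpyHostToDevice);

  scatter_reduce<<<1, n>>>(d_out, d_src, d_idx, n, ReduceAddSketch{});
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("%d %d %d %d\n", h_out[0], h_out[1], h_out[2], h_out[3]);  // 3 7 11 15

  cudaFree(d_src);
  cudaFree(d_idx);
  cudaFree(d_out);
  return 0;
}
```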
27 changes: 15 additions & 12 deletions paddle/phi/kernels/funcs/gather_scatter_functor.h
@@ -29,7 +29,8 @@ namespace funcs {
Instantiate_Template_Function_index_t(func, phi::dtype::float16) \
Instantiate_Template_Function_index_t(func, \
phi::dtype::bfloat16) \
Instantiate_Template_Function_index_t(func, unsigned char)
Instantiate_Template_Function_index_t(func, unsigned char) \
Instantiate_Template_Function_index_t(func, int16_t)

#define Instantiate_Template_Function_index_t(func, tensor_t) \
template void func<tensor_t, int>(phi::DenseTensor input, \
@@ -45,17 +46,19 @@ namespace funcs {
bool include_self, \
const phi::DeviceContext& dev_ctx);

#define Instantiate_Template_Function_With_Out(func) \
Instantiate_Template_Function_index_t_With_Out(func, int) \
Instantiate_Template_Function_index_t_With_Out(func, float) \
Instantiate_Template_Function_index_t_With_Out(func, double) \
Instantiate_Template_Function_index_t_With_Out(func, int64_t) \
Instantiate_Template_Function_index_t_With_Out( \
func, phi::dtype::float16) \
Instantiate_Template_Function_index_t_With_Out( \
func, phi::dtype::bfloat16) \
Instantiate_Template_Function_index_t_With_Out( \
func, unsigned char)
#define Instantiate_Template_Function_With_Out(func) \
Instantiate_Template_Function_index_t_With_Out(func, int) \
Instantiate_Template_Function_index_t_With_Out(func, float) \
Instantiate_Template_Function_index_t_With_Out(func, double) \
Instantiate_Template_Function_index_t_With_Out(func, int64_t) \
Instantiate_Template_Function_index_t_With_Out( \
func, phi::dtype::float16) \
Instantiate_Template_Function_index_t_With_Out( \
func, phi::dtype::bfloat16) \
Instantiate_Template_Function_index_t_With_Out( \
func, unsigned char) \
Instantiate_Template_Function_index_t_With_Out( \
func, int16_t)
#define Instantiate_Template_Function_index_t_With_Out(func, tensor_t) \
template void func<tensor_t, int>(phi::DenseTensor input, \
int dim, \
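The new `Instantiate_Template_Function_index_t(func, int16_t)` and `..._With_Out(func, int16_t)` entries are what make the `int16_t` specializations of the gather/scatter templates available at link time to kernels that only see the declarations. A toy, self-contained illustration of this explicit-instantiation pattern is below; `INSTANTIATE_FOR_INDEX_T` and `scatter_add` are made-up names with a deliberately shorter signature than the Paddle functions.

```cpp
#include <cstdint>
#include <iostream>

// Template whose definition would normally live in one translation unit
// (the .cu file), with only declarations visible elsewhere.
template <typename T, typename Index>
void scatter_add(T *out, const T *src, const Index *idx, int n) {
  for (int i = 0; i < n; ++i) out[idx[i]] += src[i];
}

// Stamp out explicit instantiations per index type, one macro call per dtype.
#define INSTANTIATE_FOR_INDEX_T(func, T)                               \
  template void func<T, int>(T *, const T *, const int *, int);        \
  template void func<T, int64_t>(T *, const T *, const int64_t *, int);

INSTANTIATE_FOR_INDEX_T(scatter_add, unsigned char)
INSTANTIATE_FOR_INDEX_T(scatter_add, int16_t)  // the dtype added in this PR

int main() {
  int16_t out[2] = {0, 0};
  int16_t src[3] = {1, 2, 3};
  int idx[3] = {0, 1, 1};
  scatter_add(out, src, idx, 3);
  std::cout << out[0] << " " << out[1] << "\n";  // 1 5
  return 0;
}
```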
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
@@ -179,5 +179,7 @@ PD_REGISTER_KERNEL(put_along_axis_grad,
double,
int64_t,
int,
int16_t,
uint8_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/put_along_axis_kernel.cu
@@ -102,6 +102,8 @@ PD_REGISTER_KERNEL(put_along_axis,
float,
double,
int64_t,
uint8_t,
int16_t,
int,
phi::dtype::float16,
phi::dtype::bfloat16) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
@@ -73,5 +73,7 @@ PD_REGISTER_KERNEL(take_along_axis_grad,
double,
int64_t,
int,
int16_t,
uint8_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/take_along_axis_kernel.cu
@@ -71,5 +71,7 @@ PD_REGISTER_KERNEL(take_along_axis,
double,
int64_t,
int,
int16_t,
uint8_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}