Merged

30 commits
7d58b91
Merge pull request #1 from PaddlePaddle/develop
AnnaTrainingG Mar 25, 2021
1021e08
Merge pull request #2 from PaddlePaddle/develop
AnnaTrainingG Mar 29, 2021
43f53fe
Merge pull request #3 from PaddlePaddle/develop
AnnaTrainingG Apr 19, 2021
d25ab26
Merge pull request #4 from PaddlePaddle/develop
AnnaTrainingG May 7, 2021
8c8717f
Merge pull request #5 from PaddlePaddle/develop
AnnaTrainingG May 25, 2021
9ddf5e8
Merge pull request #6 from PaddlePaddle/develop
AnnaTrainingG May 26, 2021
b0cbcca
Merge pull request #9 from PaddlePaddle/develop
AnnaTrainingG Jun 1, 2021
cdecaf0
Merge pull request #14 from PaddlePaddle/develop
AnnaTrainingG Jun 11, 2021
0da14c9
Merge pull request #16 from PaddlePaddle/develop
AnnaTrainingG Jun 15, 2021
ca95763
Merge pull request #17 from PaddlePaddle/develop
AnnaTrainingG Jun 22, 2021
25ba21c
Merge pull request #18 from PaddlePaddle/develop
AnnaTrainingG Jul 5, 2021
3ce9983
Merge pull request #19 from PaddlePaddle/develop
AnnaTrainingG Jul 6, 2021
61842ed
Merge pull request #20 from PaddlePaddle/develop
AnnaTrainingG Jul 12, 2021
0e2c73b
Merge pull request #21 from PaddlePaddle/develop
AnnaTrainingG Jul 28, 2021
c1e59cf
Merge pull request #22 from PaddlePaddle/develop
AnnaTrainingG Aug 2, 2021
3a54149
Merge pull request #23 from PaddlePaddle/develop
AnnaTrainingG Aug 4, 2021
7addd79
Merge pull request #24 from PaddlePaddle/develop
AnnaTrainingG Aug 11, 2021
1e843d1
Merge pull request #25 from PaddlePaddle/develop
AnnaTrainingG Aug 23, 2021
2783c76
commit for pool higher performance
AnnaTrainingG Aug 31, 2021
73d13a3
update ReduceMode
AnnaTrainingG Aug 31, 2021
b0e3fdb
update ReduceMode
AnnaTrainingG Aug 31, 2021
e1a92d6
Merge pull request #26 from PaddlePaddle/develop
AnnaTrainingG Sep 1, 2021
9349c38
Merge branch 'develop' of https://github.com/niuliling123/Paddle into…
AnnaTrainingG Sep 2, 2021
2c32248
Add API comments and specify variable names
AnnaTrainingG Sep 2, 2021
05da032
Merge pull request #27 from PaddlePaddle/develop
AnnaTrainingG Sep 3, 2021
840a652
update
AnnaTrainingG Sep 4, 2021
fc36b45
update
AnnaTrainingG Sep 4, 2021
e1fe6dc
Merge pull request #28 from PaddlePaddle/develop
AnnaTrainingG Sep 6, 2021
d9b9f42
Merge branch 'develop' of https://github.com/niuliling123/Paddle into…
AnnaTrainingG Sep 6, 2021
51f2f77
update detail to details
AnnaTrainingG Sep 6, 2021
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fused/attn_bias_add.cu.h
@@ -202,9 +202,9 @@ void SetConfigForColumnReduce(const int max_threads, const int reduce_num,
 
   int num_block = (max_threads / left_num);
   if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
-    *blocking_size = detail::GetLastPow2(reduce_num / num_block);
+    *blocking_size = details::GetLastPow2(reduce_num / num_block);
     if (*blocking_size <= 1) {
-      *blocking_size = detail::GetLastPow2(sqrt(reduce_num));
+      *blocking_size = details::GetLastPow2(sqrt(reduce_num));
     } else if (*blocking_size * 2 < reduce_num) {
       *blocking_size *= 2;
     }
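details::GetLastPow2 itself is not shown in this diff. As context for the blocking-size computation above, here is a minimal sketch of what such a helper typically computes: the largest power of two that does not exceed n. The function body below is an assumption for illustration, not code from this PR.

#include <algorithm>

// Hypothetical sketch, not taken from this PR: returns the largest power of
// two that is <= n (assumes n >= 1), by smearing the highest set bit downward
// and then clearing everything below it.
static inline int GetLastPow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}

For example, GetLastPow2(5) yields 4 and GetLastPow2(8) yields 8, which is what the blocking-size logic needs to round reduce_num / num_block down to a power-of-two tile.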
8 changes: 5 additions & 3 deletions paddle/fluid/operators/kernel_primitives/compute_primitives.h
@@ -31,13 +31,15 @@ namespace kernel_primitives {
 namespace details {
 
 #ifdef __HIPCC__
-constexpr int kMaxThread = 256;
+constexpr int kReduceMaxThread = 256;
 constexpr int kWarpSize = 64;
 #else
-constexpr int kMaxThread = 128;
+constexpr int kReduceMaxThread = 128;
 constexpr int kWarpSize = 32;
 #endif
 
+// kGlobalMode: block reduce, each block gets an output;
+// kLocalMode: thread reduce, each thread gets an output;
 enum ReduceMode { kGlobalMode, kLocalMode };
 
 template <typename T>
@@ -118,7 +120,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
  */
 template <typename T, typename ReduceOp>
 __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
-  __shared__ T shared_memory[details::kMaxThread];
+  __shared__ T shared_memory[details::kReduceMaxThread];
   shared_memory[SharedMemoryIndex(0)] = val;
   for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
     __syncthreads();
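The kReduceMaxThread rename matters because BlockYReduce reserves one shared-memory slot per thread of the block. A self-contained sketch of the same blockDim.y tree reduction follows, assuming a power-of-two blockDim.y and a hard-coded 128-slot buffer (the CUDA value of kReduceMaxThread above); this is an illustration, not the Paddle kernel.

// Sketch of a blockDim.y tree reduction: each column of threads (fixed
// threadIdx.x) folds its partial values in shared memory, halving the number
// of active rows every step. Assumes blockDim.x * blockDim.y <= 128 and
// blockDim.y is a power of two.
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockYReduceSketch(T val, ReduceOp reducer) {
  __shared__ T shared[128];  // one slot per thread; kReduceMaxThread bounds it
  int idx = threadIdx.y * blockDim.x + threadIdx.x;
  shared[idx] = val;
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.y < stride) {
      shared[idx] = reducer(shared[idx], shared[idx + stride * blockDim.x]);
    }
  }
  __syncthreads();
  return shared[threadIdx.x];  // row 0 now holds each column's result
}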
121 changes: 61 additions & 60 deletions paddle/fluid/operators/kernel_primitives/datamover_primitives.h
@@ -124,36 +124,36 @@ struct BroadcastConfig {
 template <typename Tx, typename Ty, int NX, int NY, int BlockSize>
 __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
                                          int stride_nx, int stride_ny) {
+  int thread_offset = threadIdx.x * NX;
+
   if (NY == 1 && NX == 1) {
-    dst[0] = static_cast<Ty>(src[threadIdx.x]);
+    dst[0] = static_cast<Ty>(src[thread_offset]);
   } else if (NX == 1) {
-    int dx = threadIdx.x;
 #pragma unroll
     for (int idy = 0; idy < NY; ++idy) {
-      dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]);
+      dst[idy] = static_cast<Ty>(src[thread_offset + idy * stride_ny]);
     }
   } else if (NY == 1) {
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
-      dst[idx] = static_cast<Ty>(src[idx * stride_nx]);
+      dst[idx] = static_cast<Ty>(src[thread_offset + idx * stride_nx]);
     }
   } else {
-    int dx = threadIdx.x * NX;
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
 #pragma unroll
       for (int idy = 0; idy < NY; ++idy) {
-        dst[idy * NX + idx] =
-            static_cast<Ty>(src[idx * stride_nx + dx + idy * stride_ny]);
+        dst[idy * NX + idx] = static_cast<Ty>(
+            src[thread_offset + idx * stride_nx + idy * stride_ny]);
       }
     }
   }
 }
 
 /**
- * @brief load data from src to dst, src can be 1D data or 2D data. When
- * boundary judgment is required, you need to set a to true, and a is false by
- * default.
+ * @brief load data from src to dst with stride, src can be 1D data or 2D data.
+ * When boundary judgment is required, you need to set a to true, and a is false
+ * by default.
  * @typename:
  * Tx: data type of src
  * Ty: data type of dstt
@@ -172,17 +172,17 @@ template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
 __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
                                          int size_nx, int size_ny,
                                          int stride_nx, int stride_ny) {
-  int dx = threadIdx.x * NX;
-  int size = size_nx - dx;
+  int thread_offset = threadIdx.x * NX;
+  int left_size_nx = size_nx - thread_offset;
 
   // Each branch is added for better performance
   if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
     if (IsBoundary) {
-      if (dx < size_nx) {
-        dst[0] = static_cast<Ty>(src[dx]);
+      if (left_size_nx > 0) {
+        dst[0] = static_cast<Ty>(src[thread_offset]);
       }
     } else {
-      dst[0] = static_cast<Ty>(src[dx]);
+      dst[0] = static_cast<Ty>(src[thread_offset]);
     }
   } else if (NX == 1) {  // for NX == 1 and NY != 1
 #pragma unroll
@@ -192,23 +192,23 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
           break;
         }
       }
-      dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]);
+      dst[idy] = static_cast<Ty>(src[thread_offset + idy * stride_ny]);
     }
   } else if (NY == 1) {  // for NY == 1 and NX != 1
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
       if (IsBoundary) {
-        if (idx >= size) {
+        if (idx >= left_size_nx) {
           break;
         }
       }
-      dst[idx] = static_cast<Ty>(src[idx * stride_nx + dx]);
+      dst[idx] = static_cast<Ty>(src[thread_offset + idx * stride_nx]);
     }
   } else {  // for NX != 1 and NY != 1
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
       if (IsBoundary) {
-        if (idx >= size) {
+        if (idx >= left_size_nx) {
           break;
         }
       }
@@ -219,8 +219,8 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
             break;
           }
         }
-        dst[idy * NX + idx] =
-            static_cast<Ty>(src[idx * stride_nx + dx + idy * stride_ny]);
+        dst[idy * NX + idx] = static_cast<Ty>(
+            src[thread_offset + idx * stride_nx + idy * stride_ny]);
       }
     }
   }
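The renames above all converge on one convention: thread_offset = threadIdx.x * NX gives each thread a contiguous group of NX columns, and element (idx, idy) of a thread's register tile maps to src[thread_offset + idx * stride_nx + idy * stride_ny]. A host-side sketch of that arithmetic follows; the concrete strides and thread index are made-up values for illustration.

#include <cstdio>

// Host-side sketch reproducing ReadData's index arithmetic for one thread.
// With stride_nx = 1 and stride_ny = 8 this models a row-major 2D source
// whose rows are 8 elements long; with NX = 2, thread 3 owns columns 6 and 7.
int main() {
  const int NX = 2, NY = 2;
  const int stride_nx = 1, stride_ny = 8;
  const int thread_idx_x = 3;  // stand-in for threadIdx.x
  int thread_offset = thread_idx_x * NX;
  for (int idy = 0; idy < NY; ++idy) {
    for (int idx = 0; idx < NX; ++idx) {
      printf("dst[%d] <- src[%d]\n", idy * NX + idx,
             thread_offset + idx * stride_nx + idy * stride_ny);
    }
  }
  return 0;
}

Running this prints dst[0] <- src[6], dst[1] <- src[7], dst[2] <- src[14], dst[3] <- src[15]: each thread reads an NX-wide, NY-tall tile starting at its own column group.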
@@ -251,25 +251,25 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
                                          int num) {
   if (IsBoundary) {  // blockDim.x * NX > num
-    int dx = threadIdx.x * NX;
+    int thread_offset = threadIdx.x * NX;
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
-      if (idx + dx < num) {
-        dst[idx] = src[idx + dx];
+      if (idx + thread_offset < num) {
+        dst[idx] = src[thread_offset + idx];
       }
     }
   } else {  // blockDim,x * NX < num
     const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
     const int kVectorsPerThread = NX / kVectorSize;
-    int tid = threadIdx.x * kVectorsPerThread;
+    int thread_offset = threadIdx.x * kVectorsPerThread;
 
     using VecType = details::VectorType<T, kVectorSize>;
     const VecType* vec_input = reinterpret_cast<const VecType*>(src);
     VecType vec_temp[kVectorsPerThread];
 
 #pragma unroll
     for (int i = 0; i < kVectorsPerThread; ++i) {
-      vec_temp[i] = vec_input[i + tid];
+      vec_temp[i] = vec_input[thread_offset + i];
 #pragma unroll
       for (int idx = 0; idx < NX; ++idx) {
         dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
@@ -289,39 +289,39 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
  * is 2
  * IsBoundary: whether to make boundary judgment
  * @param:
- * fix: data offset of this block, blockDim.x * blockIdx.x * NX;
+ * block_offset: data offset of this block, blockDim.x * blockIdx.x * NX;
  * config: get the global index in src, attention config was declared in host;
- * num: the num of out
+ * total_num_output: total num of output
  * stride_nx: the stride of cols
  * stride_ny: the stride of rows
  */
 template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
           bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataBc(
-    T* dst, const T* __restrict__ src, uint32_t fix,
-    details::BroadcastConfig<ShapeSize> config, int num, int stride_nx,
-    int stride_ny) {
-  uint32_t base_offset = fix + threadIdx.x * NX;
-  uint32_t offset = 0;
+    T* dst, const T* __restrict__ src, uint32_t block_offset,
+    details::BroadcastConfig<ShapeSize> config, int total_num_output,
+    int stride_nx, int stride_ny) {
+  uint32_t thread_offset = block_offset + threadIdx.x * NX;
+  uint32_t index_src = 0;
 
 #pragma unroll
   for (int ny = 0; ny < NY; ++ny) {
 #pragma unroll
     for (uint32_t nx = 0; nx < NX; ++nx) {
-      uint32_t idx = base_offset + ny * stride_ny + nx * stride_nx;
+      uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx;
+      index_src = 0;
       if (IsBoundary) {
-        if (idx >= num) {
+        if (index_output >= total_num_output) {
           break;
         }
       }
-      offset = 0;
 #pragma unroll
       for (int i = 0; i < ShapeSize; ++i) {
-        auto fast_divmoder = config.divmoders[i].Divmod(idx);
-        idx = fast_divmoder.val[0];
-        offset += fast_divmoder.val[1] * config.strides[i];
+        auto fast_divmoder = config.divmoders[i].Divmod(index_output);
+        index_output = fast_divmoder.val[0];
+        index_src += fast_divmoder.val[1] * config.strides[i];
       }
-      dst[nx + ny * NX] = src[offset];
+      dst[nx + ny * NX] = src[index_src];
     }
   }
 }
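ReadDataBc's inner loop decomposes a flat output index into per-dimension coordinates with Paddle's fast divmod (val[0] is the quotient, val[1] the remainder), accumulating the source index as it goes; broadcast dimensions contribute nothing because their stride is zero. The same mapping written with plain division and modulo looks like the sketch below; the helper name and array layout are assumptions for illustration, not the BroadcastConfig API.

#include <cstdint>

// Sketch of the broadcast index mapping: peel index_output one dimension at a
// time (divisors[i] is that dimension's extent); the remainder is the
// coordinate along that dimension, and strides[i] is 0 wherever the source
// tensor was broadcast, so those coordinates are ignored.
uint32_t BroadcastSrcIndex(uint32_t index_output, const uint32_t* divisors,
                           const uint32_t* strides, int shape_size) {
  uint32_t index_src = 0;
  for (int i = 0; i < shape_size; ++i) {
    uint32_t coord = index_output % divisors[i];
    index_output /= divisors[i];
    index_src += coord * strides[i];
  }
  return index_src;
}

The renamed variables make this structure visible in the kernel itself: index_output is consumed digit by digit while index_src is built up, instead of one overloaded idx/offset pair.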
@@ -338,7 +338,7 @@ __device__ __forceinline__ void ReadDataBc(
  * IndexCal: get the global index in src, attention config was declared in host;
  * IsBoundary: whether to make boundary judgment
  * @param:
- * fix: data offset of this block, blockDim.x * blockIdx.x * NX;
+ * block_offset: data offset of this block, blockDim.x * blockIdx.x * NX;
  * index_cal: get the global index in src, attention config was declared in
  * host;
  * size_nx: number of columns to be processed by the current block
@@ -350,27 +350,27 @@ __device__ __forceinline__ void ReadDataBc(
 template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
           typename IndexCal, bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataReduce(
-    T* dst, const T* __restrict__ src, int fix, const IndexCal& index_cal,
-    int size_nx, int size_ny, int stride_nx, int stride_ny,
-    bool reduce_last_dim) {
-  int base_offset = fix;
+    T* dst, const T* __restrict__ src, int block_offset,
+    const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx,
+    int stride_ny, bool reduce_last_dim) {
+  int thread_offset = 0;
   if (reduce_last_dim) {
-    base_offset += threadIdx.x;
+    thread_offset = block_offset + threadIdx.x;
   } else {
-    base_offset += threadIdx.y;
+    thread_offset = block_offset + threadIdx.y;
   }
 
   if (NX == 1) {
 #pragma unroll
     for (int ny = 0; ny < NY; ++ny) {
       if (IsBoundary) {
-        if (base_offset >= size_ny) {
+        if (thread_offset >= size_ny) {
           break;
         }
       }
-      uint32_t offset = index_cal(base_offset);
-      dst[ny] = src[offset];
-      base_offset += stride_ny;
+      uint32_t index_src = index_cal(thread_offset);
+      dst[ny] = src[index_src];
+      thread_offset += stride_ny;
     }
   } else {
 #pragma unroll
@@ -387,15 +387,16 @@ __device__ __forceinline__ void ReadDataReduce(
             break;
           }
         }
-        uint32_t offset = index_cal(base_offset);
-        dst[nx + ny * NX] = src[offset];
-        base_offset += stride_ny;
+        uint32_t index_src = index_cal(thread_offset);
+        dst[nx + ny * NX] = src[index_src];
+        thread_offset += stride_ny;
       }
+      thread_offset += stride_nx;
     }
   }
 }
 
-/** @brief: WriteData
+/**
  * @brief store data from src to dst, src can be 1D data, you should set NY = 1.
  * When boundary judgment is required, you need to set a to true, and a is false
  * by default.
@@ -412,26 +413,26 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
                                           int num) {
   if (IsBoundary) {
-    int dx = threadIdx.x * NX;
+    int thread_offset = threadIdx.x * NX;
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
-      if ((idx + dx) < num) {
-        dst[idx + dx] = src[idx];
+      if ((thread_offset + idx) < num) {
+        dst[thread_offset + idx] = src[idx];
       }
     }
   } else {
     // Vector type
     const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
     const int kVectorsPerThread = NX / kVectorSize;
 
-    int dx = threadIdx.x * kVectorsPerThread;
+    int thread_offset = threadIdx.x * kVectorsPerThread;
     using VecType = details::VectorType<T, kVectorSize>;
     VecType* vec_dst = reinterpret_cast<VecType*>(dst);
     VecType vec_temp[kVectorsPerThread];
 #pragma unroll
     for (int idx = 0; idx < kVectorsPerThread; ++idx) {
       vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
-      vec_dst[dx + idx] = vec_temp[idx];
+      vec_dst[thread_offset + idx] = vec_temp[idx];
     }
   }
 }
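The non-boundary branch of WriteData relies on vectorized stores: when NX is a multiple of 4, each thread's NX elements leave in NX/4 wide transactions. Below is a float-only sketch of that path using CUDA's built-in float4 instead of details::VectorType; it is an illustration under stated assumptions, not the Paddle code.

// Sketch of the vectorized store path for T = float and kVectorSize = 4:
// reinterpret both buffers as float4 so each iteration issues one 16-byte
// store. Assumes dst is 16-byte aligned and NX % 4 == 0.
template <int NX>
__device__ __forceinline__ void WriteDataVec4(float* dst, const float* src) {
  static_assert(NX % 4 == 0, "sketch covers only the 4-wide vector path");
  constexpr int kVectorsPerThread = NX / 4;
  int thread_offset = threadIdx.x * kVectorsPerThread;
  float4* vec_dst = reinterpret_cast<float4*>(dst);
  const float4* vec_src = reinterpret_cast<const float4*>(src);
#pragma unroll
  for (int i = 0; i < kVectorsPerThread; ++i) {
    vec_dst[thread_offset + i] = vec_src[i];
  }
}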
@@ -32,7 +32,6 @@ static __device__ __forceinline__ platform::float16 LogFunctor(
 static __device__ __forceinline__ float LogFunctor(float x) { return logf(x); }
 static __device__ __forceinline__ double LogFunctor(double x) { return log(x); }
 
-}  // namespace details
 /*************************** Compute Functor****************************/
 // for margin_cross_entropy
 template <typename Tx, typename Ty = Tx>
@@ -75,6 +74,7 @@ struct DivideFunctor {
   T n_inv;
 };
 
+}  // namespace details
 }  // namespace kernel_primitives
 }  // namespace operators
 }  // namespace paddle
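The two hunks above move the closing brace of namespace details from just after LogFunctor to the bottom of the file, so DivideFunctor and the other compute functors now live in kernel_primitives::details. A hypothetical call site after this change might look like the following; the kernel name and launch shape are invented for illustration, and the functor's divide-by-reciprocal behavior is inferred from the n_inv member shown above.

namespace kp = paddle::operators::kernel_primitives;

// Hypothetical kernel, not from this PR: scales every element by 1/4 using
// the now details-qualified functor, which multiplies by a precomputed
// reciprocal rather than dividing.
__global__ void DivideBy4(const float* in, float* out, int n) {
  kp::details::DivideFunctor<float> div(4);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = div(in[i]);
  }
}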