Commit 27b9fd7

delete template param and opt perf

1 parent db5c956 · commit 27b9fd7

File tree

2 files changed: +105 -165 lines changed

paddle/fluid/operators/reduce_ops/reduce_op.cu.h

Lines changed: 104 additions & 164 deletions
@@ -39,6 +39,7 @@ namespace cub = hipcub;
 // Reduce split or not, Whether to use ReduceHigherDim
 #define REDUCE_SPLIT_BOUNDARY 512
 #define REDUCE_VEC 4
+#define MAX_DIM 10
 
 namespace paddle {
 namespace operators {
@@ -126,10 +127,10 @@ static inline void CheckReduceRank(int reduce_rank, int rank) {
 template <typename T, size_t ElementCount, typename VectorLikeType>
 static inline paddle::framework::Array<T, ElementCount> VectorToArray(
     const VectorLikeType& vec) {
-  PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
+  PADDLE_ENFORCE_LE(vec.size(), ElementCount,
                     platform::errors::InvalidArgument(
                         "Cub reduce Array: size not match. Received "
-                        "vec.size() %d != ElementCount %d.",
+                        "vec.size() %d > ElementCount %d.",
                         vec.size(), ElementCount));
   size_t n = static_cast<size_t>(vec.size());
   paddle::framework::Array<T, ElementCount> ret;
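
Note: with ranks now runtime values, every dimension/stride vector is packed into a fixed-capacity Array<T, MAX_DIM> regardless of its actual length, which is why the check relaxes from EQ to LE: a vector shorter than ElementCount is valid, and the unused trailing slots are never read (IndexCalculator::Get below stops at i == dim). A minimal host-side sketch of that packing; FixedArray and PackToFixed are illustrative names, not Paddle's:

#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative stand-in for paddle::framework::Array<T, N>: a fixed-size
// POD array that can be passed by value into a CUDA kernel.
template <typename T, size_t N>
struct FixedArray {
  T data[N];
};

constexpr size_t kMaxDim = 10;  // mirrors MAX_DIM in the diff

// Pack a runtime-rank vector (size <= kMaxDim) into the fixed array;
// trailing slots stay value-initialized and are never dereferenced.
template <typename T>
FixedArray<T, kMaxDim> PackToFixed(const std::vector<T>& vec) {
  assert(vec.size() <= kMaxDim);  // the PADDLE_ENFORCE_LE in the diff
  FixedArray<T, kMaxDim> ret{};
  for (size_t i = 0; i < vec.size(); ++i) ret.data[i] = vec[i];
  return ret;
}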
@@ -150,6 +151,32 @@ enum ReduceType {
   kReduceAny = 0x03,  // when reduce_dim.size() > 1
 };
 
+struct IndexCalculator {
+  IndexCalculator(int dim, paddle::framework::Array<int, MAX_DIM> strides,
+                  paddle::framework::Array<int, MAX_DIM> dims,
+                  paddle::framework::Array<FastDivMod, MAX_DIM> divmoders)
+      : dim(dim), strides(strides), dims(dims), divmoders(divmoders) {}
+
+  __device__ inline int Get(int offset) const {
+    int index = 0;
+#pragma unroll
+    for (int i = 0; i < MAX_DIM; ++i) {
+      if (i == dim) {
+        break;
+      }
+      auto divmod = divmoders[i].Divmod(offset);
+      index += (divmod.val[0] * strides[dims[i]]);
+      offset = divmod.val[1];
+    }
+    return index;
+  }
+
+  int dim;
+  paddle::framework::Array<int, MAX_DIM> strides;
+  paddle::framework::Array<int, MAX_DIM> dims;
+  paddle::framework::Array<FastDivMod, MAX_DIM> divmoders;
+};
+
 // reduce config
 template <typename Ty>
 struct ReduceConfig {
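
Note: IndexCalculator::Get turns a flat logical offset into a physical memory offset. Each FastDivMod peels one coordinate off the offset (divmod.val[0]) and weights it by the stride of the corresponding input axis, leaving the remainder (divmod.val[1]) for the next dimension. The loop runs to the compile-time bound MAX_DIM and breaks at the runtime rank dim, so #pragma unroll still applies even though the rank is no longer a template parameter. A host-side analogue with ordinary division, same arithmetic under illustrative names (FastDivMod only changes how the division is evaluated, not the result):

#include <vector>

// Host-side analogue of IndexCalculator::Get. 'divisors' plays the role of
// the FastDivMod array; 'dims' maps the i-th walked axis to the input axis
// whose stride is strides[dims[i]].
int GetIndex(int offset, int dim, const std::vector<int>& strides,
             const std::vector<int>& dims, const std::vector<int>& divisors) {
  int index = 0;
  for (int i = 0; i < dim; ++i) {
    int coord = offset / divisors[i];  // divmod.val[0]
    offset = offset % divisors[i];     // divmod.val[1]
    index += coord * strides[dims[i]];
  }
  return index;
}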
@@ -577,47 +604,38 @@ __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
 
 // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
 // function will be used
-// blockId.x -> left_num, threadId.x -> reduce_num
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int Rank, int ReduceRank>
-__device__ void ReduceAny(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
-    int reduce_num, int left_num, int block_size,
-    paddle::framework::Array<int, Rank> x_strides,
-    paddle::framework::Array<int, ReduceRank> reduce_dim,
-    paddle::framework::Array<FastDivMod, ReduceRank> reduce_divmoders,
-    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
-    paddle::framework::Array<FastDivMod, Rank - ReduceRank> left_divmoders) {
+template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+__device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
+                          TransformOp transformer, Ty init, int reduce_num,
+                          int left_num, int block_size, bool reduce_lastdim,
+                          const IndexCalculator& reduce_index_calculator,
+                          const IndexCalculator& left_index_calculator) {
+  int input_idx, left_idx, stride;
   // the last dim gets involved in reduction
-  if (reduce_dim[ReduceRank - 1] == Rank - 1) {
-    int input_idx = blockIdx.y * blockDim.x + threadIdx.x;
-    int left_idx = blockIdx.x;
-    int stride = gridDim.y * blockDim.x;
-    int input_offset = 0;
-    // calculate the offset, means the addr where each thread really start.
-#pragma unroll
-    for (int i = 0; i < Rank - ReduceRank; ++i) {
-      auto divmod = left_divmoders[i].Divmod(left_idx);
-      input_offset += (divmod.val[0] * x_strides[left_dim[i]]);
-      left_idx = divmod.val[1];
-    }
-    const Tx* input = x + input_offset;
-    Ty reduce_var = init;
+  if (reduce_lastdim) {
+    input_idx = blockIdx.y * blockDim.x + threadIdx.x;
+    left_idx = blockIdx.x;
+    stride = gridDim.y * blockDim.x;
+  } else {
+    input_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    left_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    stride = gridDim.y * blockDim.y;
+  }
+  // calculate the offset, means the addr where each thread really start.
+  int input_offset = left_index_calculator.Get(left_idx);
+  const Tx* input = x + input_offset;
+  Ty reduce_var = init;
 
+  // 1. reduce for each thread
+  if (left_idx < left_num) {
     // load REDUCE_VEC data once, and then compute
     Tx inputs[REDUCE_VEC];
     int bound = reduce_num - (REDUCE_VEC - 1) * stride;
     while (input_idx < bound) {
 #pragma unroll
       for (int i = 0; i < REDUCE_VEC; ++i) {
         int reduce_idx = input_idx + i * stride;
-        int idx_x = 0;
-#pragma unroll
-        for (int j = 0; j < ReduceRank; ++j) {
-          auto divmod = reduce_divmoders[j].Divmod(reduce_idx);
-          idx_x += (divmod.val[0] * x_strides[reduce_dim[j]]);
-          reduce_idx = divmod.val[1];
-        }
+        int idx_x = reduce_index_calculator.Get(reduce_idx);
         inputs[i] = input[idx_x];
       }
 #pragma unroll
@@ -635,13 +653,7 @@ __device__ void ReduceAny(
         break;
       }
       int reduce_idx = input_idx;
-      int idx_x = 0;
-#pragma unroll
-      for (int j = 0; j < ReduceRank; ++j) {
-        auto divmod = reduce_divmoders[j].Divmod(reduce_idx);
-        idx_x += (divmod.val[0] * x_strides[reduce_dim[j]]);
-        reduce_idx = divmod.val[1];
-      }
+      int idx_x = reduce_index_calculator.Get(reduce_idx);
       inputs[i] = input[idx_x];
       input_idx += stride;
     }
@@ -654,47 +666,12 @@ __device__ void ReduceAny(
       reduce_var = reducer(reduce_var, transformer(inputs[i]));
       input_idx += stride;
     }
+  }
 
-    __syncthreads();
-
-    reduce_var = BlockReduce(reduce_var, reducer);
-
-    if (threadIdx.x == 0) {
-      y[blockIdx.x + blockIdx.y * gridDim.x] = reduce_var;
-    }
-  } else {
+  // 2. reduce in block y
+  if (blockDim.y > 1) {
     // need shared_memory to reduce block y
     __shared__ Ty shared_memory[detail::kMaxThread];
-    int left_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    int input_idx = (blockIdx.y * blockDim.y + threadIdx.y) * block_size;
-    int left_idx_tmp = left_idx;
-    int input_offset = 0;
-#pragma unroll
-    for (int i = 0; i < Rank - ReduceRank; ++i) {
-      auto divmod = left_divmoders[i].Divmod(left_idx_tmp);
-      input_offset += (divmod.val[0] * x_strides[left_dim[i]]);
-      left_idx_tmp = divmod.val[1];
-    }
-    const Tx* input = x + input_offset;
-    Ty reduce_var = init;
-
-    if (left_idx < left_num) {
-      int loop = reduce_num - input_idx;
-      loop = loop > block_size ? block_size : loop;
-      for (int i = 0; i < loop; ++i) {
-        int reduce_idx = i + input_idx;
-        int idx_x = 0;
-#pragma unroll
-        for (int j = 0; j < ReduceRank; ++j) {
-          auto divmod = reduce_divmoders[j].Divmod(reduce_idx);
-          idx_x += (divmod.val[0] * x_strides[reduce_dim[j]]);
-          reduce_idx = divmod.val[1];
-        }
-        reduce_var =
-            reducer(reduce_var, static_cast<Ty>(transformer(input[idx_x])));
-      }
-    }
-
     shared_memory[SharedMemoryIndex(0)] = reduce_var;
     for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
       __syncthreads();
@@ -704,7 +681,16 @@ __device__ void ReduceAny(
       }
       shared_memory[SharedMemoryIndex(0)] = reduce_var;
     }
+  }
+  __syncthreads();
 
+  if (reduce_lastdim) {
+    // 3. reduce in block x
+    reduce_var = BlockReduce(reduce_var, reducer);
+    if (threadIdx.x == 0) {
+      y[blockIdx.x + blockIdx.y * gridDim.x] = reduce_var;
+    }
+  } else {
     if (left_idx < left_num && threadIdx.y == 0) {
       y[blockIdx.y * left_num + left_idx] = reduce_var;
     }
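
Note: after this restructuring, both layouts share one body with three numbered phases: (1) each thread serially reduces a strided slice, (2) partial results are folded across blockDim.y in shared memory, and (3) in the last-dim case the surviving row is further folded across blockDim.x. A self-contained sketch of that shape, with a generic binary op standing in for reducer/transformer; this illustrates the pattern, not the Paddle kernel, and assumes blockDim.x and blockDim.y are powers of two with blockDim.x * blockDim.y <= 1024:

template <typename T, typename Op>
__device__ T ThreePhaseReduce(const T* in, int n, T init, Op op,
                              bool reduce_lastdim) {
  __shared__ T smem[1024];
  int tid = threadIdx.y * blockDim.x + threadIdx.x;

  // 1. per-thread serial reduction over a strided slice of the input
  T val = init;
  for (int i = tid; i < n; i += blockDim.x * blockDim.y) {
    val = op(val, in[i]);
  }
  smem[tid] = val;

  // 2. tree reduction across the y dimension of the block
  for (int s = blockDim.y / 2; s > 0; s >>= 1) {
    __syncthreads();
    if (threadIdx.y < s) {
      val = op(val, smem[tid + s * blockDim.x]);
      smem[tid] = val;
    }
  }
  __syncthreads();

  // 3. last-dim case only: fold row y == 0 across the x dimension too
  //    (simplified to a shared-memory pass; the real kernel uses BlockReduce)
  if (reduce_lastdim) {
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
      if (threadIdx.y == 0 && threadIdx.x < s) {
        val = op(val, smem[threadIdx.x + s]);
        smem[threadIdx.x] = val;
      }
      __syncthreads();
    }
  }
  // last-dim: thread (0,0) holds the block result; otherwise each thread
  // with threadIdx.y == 0 holds the result for its x column.
  return val;
}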
@@ -713,15 +699,13 @@ __device__ void ReduceAny(
 
 // module function designed for global function
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int Rank, int ReduceRank, int ReduceType>
-__device__ void ReduceModule(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
-    int reduce_num, int left_num, int blocking_size,
-    paddle::framework::Array<int, Rank> x_strides,
-    paddle::framework::Array<int, ReduceRank> reduce_dim,
-    paddle::framework::Array<FastDivMod, ReduceRank> reduce_divmoders,
-    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
-    paddle::framework::Array<FastDivMod, Rank - ReduceRank> left_divmoders) {
+          int ReduceType>
+__device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
+                             TransformOp transformer, Ty init, int reduce_num,
+                             int left_num, int blocking_size,
+                             bool reduce_lastdim,
+                             const IndexCalculator& reduce_index_calculator,
+                             const IndexCalculator& left_index_calculator) {
   if (ReduceType == ReduceType::kReduceLastDim) {
     ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
                                                  init, reduce_num);
@@ -733,48 +717,52 @@ __device__ void ReduceModule(
 
     // reduce_rank >= 2
   } else {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank>(
+    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
         x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
-        x_strides, reduce_dim, reduce_divmoders, left_dim, left_divmoders);
+        reduce_lastdim, reduce_index_calculator, left_index_calculator);
   }
 }
 
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int Rank, int ReduceRank, int ReduceType>
-__global__ void ReduceKernelFunction(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
-    int reduce_num, int left_num, int block_size,
-    paddle::framework::Array<int, Rank> x_strides,
-    paddle::framework::Array<int, ReduceRank> reduce_dim,
-    paddle::framework::Array<FastDivMod, ReduceRank> reduce_divmoders,
-    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
-    paddle::framework::Array<FastDivMod, Rank - ReduceRank> left_divmoders) {
-  ReduceModule<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank, ReduceType>(
-      x, y, reducer, transformer, init, reduce_num, left_num, block_size,
-      x_strides, reduce_dim, reduce_divmoders, left_dim, left_divmoders);
+          int ReduceType>
+__global__ void ReduceKernelFunction(const Tx* x, Ty* y, ReduceOp reducer,
+                                     TransformOp transformer, Ty init,
+                                     int reduce_num, int left_num,
+                                     int blocking_size, bool reduce_lastdim,
+                                     IndexCalculator reduce_index_calculator,
+                                     IndexCalculator left_index_calculator) {
+  ReduceModule<Tx, Ty, ReduceOp, TransformOp, ReduceType>(
+      x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
+      reduce_lastdim, reduce_index_calculator, left_index_calculator);
 }
 
-template <typename Tx, typename Ty, typename ReduceOp, int Rank, int ReduceRank>
+template <typename Tx, typename Ty, typename ReduceOp>
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                                const ReduceOp& reducer, Ty init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   using TransformOp = typename ReduceOp::Transformer;
 
+  int reduce_rank = config.reduce_strides.size();
+  int left_rank = config.left_strides.size();
+  auto reduce_index_calculator = IndexCalculator(
+      reduce_rank, detail::VectorToArray<int, MAX_DIM>(config.x_strides),
+      detail::VectorToArray<int, MAX_DIM>(config.reduce_dim),
+      detail::VectorToArray<FastDivMod, MAX_DIM>(config.reduce_divmoders));
+  auto left_index_calculator = IndexCalculator(
+      left_rank, detail::VectorToArray<int, MAX_DIM>(config.x_strides),
+      detail::VectorToArray<int, MAX_DIM>(config.left_dim),
+      detail::VectorToArray<FastDivMod, MAX_DIM>(config.left_divmoders));
+
 #define CUB_REDUCE_TYPE_CASE(type)                                            \
   case type: {                                                                \
     constexpr auto kReduceType = type;                                        \
     ReduceKernelFunction<                                                     \
-        Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank,                      \
+        Tx, Ty, ReduceOp, TransformOp,                                        \
         kReduceType><<<config.grid, config.block, 0, stream>>>(               \
         x_data, config.output_data, reducer, TransformOp(config.reduce_num), \
         init, config.reduce_num, config.left_num, config.blocking_size,      \
-        detail::VectorToArray<int, Rank>(config.x_strides),                   \
-        detail::VectorToArray<int, ReduceRank>(config.reduce_dim),            \
-        detail::VectorToArray<FastDivMod, ReduceRank>(                        \
-            config.reduce_divmoders),                                         \
-        detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),       \
-        detail::VectorToArray<FastDivMod, Rank - ReduceRank>(                 \
-            config.left_divmoders));                                          \
+        config.reduce_lastdim, reduce_index_calculator,                       \
+        left_index_calculator);                                               \
   } break
 
   switch (config.reduce_type) {
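
Note: with the rank parameters gone, this switch instantiates ReduceKernelFunction once per ReduceType instead of once per (ReduceType, Rank, ReduceRank) combination, and both IndexCalculator objects are built once on the host beforehand. Assuming the case bodies (outside this hunk) invoke the macro as CUB_REDUCE_TYPE_CASE(ReduceType::kReduceAny), the expansion within LaunchReduceKernel's scope reads:

case ReduceType::kReduceAny: {
  constexpr auto kReduceType = ReduceType::kReduceAny;
  ReduceKernelFunction<
      Tx, Ty, ReduceOp, TransformOp,
      kReduceType><<<config.grid, config.block, 0, stream>>>(
      x_data, config.output_data, reducer, TransformOp(config.reduce_num),
      init, config.reduce_num, config.left_num, config.blocking_size,
      config.reduce_lastdim, reduce_index_calculator, left_index_calculator);
} break;

Every case shares the same host-built index metadata, captured from the enclosing scope.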
@@ -795,61 +783,13 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
   }
 
   ReduceKernelFunction<
-      Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, Rank, ReduceRank,
+      Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>,
       ReduceType::kReduceHigherDim><<<grid, block, 0, stream>>>(
       config.output_data, y_data, reducer,
       detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
-      config.left_num, config.grid.y,
-      detail::VectorToArray<int, Rank>(config.x_strides),
-      detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
-      detail::VectorToArray<FastDivMod, ReduceRank>(config.reduce_divmoders),
-      detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),
-      detail::VectorToArray<FastDivMod, Rank - ReduceRank>(
-          config.left_divmoders));
-  }
-}
-
-template <typename Tx, typename Ty, typename ReduceOp>
-static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
-                             const ReduceOp& reducer, Ty init,
-                             gpuStream_t stream, ReduceConfig<Ty> config) {
-  int reduce_rank = config.reduce_strides.size();
-  int rank = config.x_strides.size();
-
-#define CUB_RANK_CASE(i, ...)             \
-  case i: {                               \
-    constexpr auto Rank = i;              \
-    switch (reduce_rank) { __VA_ARGS__; } \
-  } break
-
-#define CUB_REDUCE_RANK_CASE(i, ...)                        \
-  case i: {                                                 \
-    constexpr auto ReduceRank = i;                          \
-    LaunchReduceKernel<Tx, Ty, ReduceOp, Rank, ReduceRank>( \
-        x_data, y_data, reducer, init, stream, config);     \
-  } break
-
-  detail::CheckReduceRank(reduce_rank, rank);
-  switch (rank) {
-    CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););
-
-    CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););
-
-    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););
-
-    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););
-
-    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););
-
-    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););
-
-    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););
-
-    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););
+      config.left_num, config.grid.y, config.reduce_lastdim,
+      reduce_index_calculator, left_index_calculator);
   }
-
-#undef CUB_REDUCE_RANK_CASE
-#undef CUB_RANK_CASE
 }
 
 template <typename Tx, typename Ty,
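
Note: the deleted ReduceKernelImpl existed only to map runtime ranks onto template constants; each CUB_RANK_CASE / CUB_REDUCE_RANK_CASE pair forced the compiler to emit a full copy of the kernel, twelve (Rank, ReduceRank) combinations in all. A miniature of the old pattern next to the new one, with illustrative names rather than Paddle code:

// Old: runtime ranks become template constants by exhaustive enumeration,
// so one kernel is compiled per (Rank, ReduceRank) combination.
template <int Rank, int ReduceRank>
void LaunchOld() { /* full kernel instantiation per combination */ }

void DispatchOld(int rank, int reduce_rank) {
  switch (rank) {
    case 2:
      LaunchOld<2, 1>();
      break;
    case 3:
      if (reduce_rank == 1) LaunchOld<3, 1>();
      else LaunchOld<3, 2>();
      break;
    // ... the deleted code enumerated ranks up to 9 ...
  }
}

// New: rank travels as a runtime field of IndexCalculator, so a single
// instantiation serves every rank. The device loop keeps its compile-time
// trip count (MAX_DIM) and breaks early at the runtime rank, which
// preserves #pragma unroll.
void DispatchNew(int /*rank*/, int /*reduce_rank*/) {
  // LaunchReduceKernel<Tx, Ty, ReduceOp>(...);  // one instantiation total
}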
@@ -899,8 +839,8 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     return;
   }
 
-  ReduceKernelImpl<Tx, Ty, ReduceOp<Tx, Ty>>(x_data, y_data, reducer,
-                                             reducer.initial(), stream, config);
+  LaunchReduceKernel<Tx, Ty, ReduceOp<Tx, Ty>>(
+      x_data, y_data, reducer, reducer.initial(), stream, config);
 }
 
 template <typename Tx, template <typename, typename> class ReduceOp>

paddle/fluid/platform/fast_divmod.h

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ struct FastDivMod {
     return (t + n) >> shift_val;
   }
 
-  __device__ __forceinline__ DivModT Divmod(uint32_t n) {
+  __device__ __forceinline__ DivModT Divmod(uint32_t n) const {
     uint32_t q = Div(n);
     DivModT result = {q, n - q * divisor};
     return result;
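
Note: the const qualifier is required by the reduce_op.cu.h change above: IndexCalculator::Get is a const member function and ReduceAny receives the calculators by const reference, so divmoders[i] is a const FastDivMod and only const member functions may be called on it. A minimal reproduction of the rule, with hypothetical names:

struct Divider {
  // 'const' makes this callable on const Divider objects.
  int Divmod(int n) const { return n / d; }
  int d = 2;
};

struct Calculator {
  int Get(int offset) const {
    // Inside a const member function, 'div' is const. If Divider::Divmod
    // were not marked const, this call would fail to compile.
    return div.Divmod(offset);
  }
  Divider div;
};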
