@@ -39,6 +39,7 @@ namespace cub = hipcub;
 // Reduce split or not, Whether to use ReduceHigherDim
 #define REDUCE_SPLIT_BOUNDARY 512
 #define REDUCE_VEC 4
+#define MAX_DIM 10

 namespace paddle {
 namespace operators {
@@ -126,10 +127,10 @@ static inline void CheckReduceRank(int reduce_rank, int rank) {
 template <typename T, size_t ElementCount, typename VectorLikeType>
 static inline paddle::framework::Array<T, ElementCount> VectorToArray(
     const VectorLikeType& vec) {
-  PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
+  PADDLE_ENFORCE_LE(vec.size(), ElementCount,
                     platform::errors::InvalidArgument(
                         "Cub reduce Array: size not match. Received "
-                        "vec.size() %d != ElementCount %d.",
+                        "vec.size() %d > ElementCount %d.",
                         vec.size(), ElementCount));
   size_t n = static_cast<size_t>(vec.size());
   paddle::framework::Array<T, ElementCount> ret;
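The switch from `PADDLE_ENFORCE_EQ` to `PADDLE_ENFORCE_LE` is what lets everything below share one fixed capacity (`MAX_DIM`): a runtime-rank vector is copied into the front of a fixed-size array and the remaining slots keep their default value. A minimal stand-alone sketch of that padding behaviour, using a hypothetical `FixedArray` in place of `paddle::framework::Array` so it compiles outside Paddle:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for paddle::framework::Array, only to illustrate
// how VectorToArray pads a runtime-rank vector into a fixed-capacity array.
template <typename T, size_t N>
struct FixedArray {
  T data[N] = {};
  T& operator[](size_t i) { return data[i]; }
};

template <typename T, size_t ElementCount, typename VectorLikeType>
FixedArray<T, ElementCount> ToFixedArray(const VectorLikeType& vec) {
  assert(vec.size() <= ElementCount);  // mirrors the PADDLE_ENFORCE_LE above
  FixedArray<T, ElementCount> ret;
  for (size_t i = 0; i < vec.size(); ++i) ret[i] = vec[i];
  return ret;  // trailing slots stay value-initialized
}

int main() {
  // Strides of a rank-3 tensor dropped into a MAX_DIM(=10)-slot array.
  auto strides = ToFixedArray<int, 10>(std::vector<int>{20, 5, 1});
  return strides[0] == 20 ? 0 : 1;
}
```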
@@ -150,6 +151,32 @@ enum ReduceType {
   kReduceAny = 0x03,  // when reduce_dim.size() > 1
 };

+struct IndexCalculator {
+  IndexCalculator(int dim, paddle::framework::Array<int, MAX_DIM> strides,
+                  paddle::framework::Array<int, MAX_DIM> dims,
+                  paddle::framework::Array<FastDivMod, MAX_DIM> divmoders)
+      : dim(dim), strides(strides), dims(dims), divmoders(divmoders) {}
+
+  __device__ inline int Get(int offset) const {
+    int index = 0;
+#pragma unroll
+    for (int i = 0; i < MAX_DIM; ++i) {
+      if (i == dim) {
+        break;
+      }
+      auto divmod = divmoders[i].Divmod(offset);
+      index += (divmod.val[0] * strides[dims[i]]);
+      offset = divmod.val[1];
+    }
+    return index;
+  }
+
+  int dim;
+  paddle::framework::Array<int, MAX_DIM> strides;
+  paddle::framework::Array<int, MAX_DIM> dims;
+  paddle::framework::Array<FastDivMod, MAX_DIM> divmoders;
+};
+
 // reduce config
 template <typename Ty>
 struct ReduceConfig {
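`IndexCalculator` is the piece that makes the compile-time `Rank`/`ReduceRank` parameters unnecessary: `divmoders` decompose a linear offset over one sub-shape (the reduce dims or the left dims), `dims` map each sub-dimension back to its axis in the full tensor, and `strides` are the full tensor's strides. Because the loop runs to `MAX_DIM` with an early break at `dim`, one non-templated struct covers every rank up to ten. A host-side sketch of the same arithmetic with plain `/` and `%` in place of `FastDivMod` (names here are illustrative, not Paddle APIs):

```cpp
#include <cstdio>
#include <vector>

// Host-side illustration of IndexCalculator::Get: peel coordinates off a
// linear offset over a sub-shape, then re-project each coordinate through
// the full tensor's strides.
int GetOffset(int offset, const std::vector<int>& sub_shape,
              const std::vector<int>& axis_of_sub_dim,
              const std::vector<int>& full_strides) {
  int index = 0;
  for (size_t i = 0; i < sub_shape.size(); ++i) {
    // Divisor for sub-dim i is the product of the sub-dims to its right,
    // i.e. the row-major stride inside the sub-shape (what FastDivMod holds).
    int divisor = 1;
    for (size_t j = i + 1; j < sub_shape.size(); ++j) divisor *= sub_shape[j];
    index += (offset / divisor) * full_strides[axis_of_sub_dim[i]];
    offset %= divisor;
  }
  return index;
}

int main() {
  // x has shape [4, 3, 5] (row-major strides {15, 5, 1}); reduce axes {0, 2}.
  // Element 7 of the reduce iteration space (shape [4, 5]) has coordinates
  // (1, 2), which sit at offset 1 * 15 + 2 * 1 = 17 inside x.
  printf("%d\n", GetOffset(7, {4, 5}, {0, 2}, {15, 5, 1}));  // prints 17
  return 0;
}
```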
@@ -577,47 +604,38 @@ __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,

 // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
 // function will be used
-// blockId.x -> left_num, threadId.x -> reduce_num
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int Rank, int ReduceRank>
-__device__ void ReduceAny(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
-    int reduce_num, int left_num, int block_size,
-    paddle::framework::Array<int, Rank> x_strides,
-    paddle::framework::Array<int, ReduceRank> reduce_dim,
-    paddle::framework::Array<FastDivMod, ReduceRank> reduce_divmoders,
-    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
-    paddle::framework::Array<FastDivMod, Rank - ReduceRank> left_divmoders) {
+template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+__device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
+                          TransformOp transformer, Ty init, int reduce_num,
+                          int left_num, int block_size, bool reduce_lastdim,
+                          const IndexCalculator& reduce_index_calculator,
+                          const IndexCalculator& left_index_calculator) {
+  int input_idx, left_idx, stride;
   // the last dim gets involved in reduction
-  if (reduce_dim[ReduceRank - 1] == Rank - 1) {
-    int input_idx = blockIdx.y * blockDim.x + threadIdx.x;
-    int left_idx = blockIdx.x;
-    int stride = gridDim.y * blockDim.x;
-    int input_offset = 0;
-    // calculate the offset, means the addr where each thread really start.
-#pragma unroll
-    for (int i = 0; i < Rank - ReduceRank; ++i) {
-      auto divmod = left_divmoders[i].Divmod(left_idx);
-      input_offset += (divmod.val[0] * x_strides[left_dim[i]]);
-      left_idx = divmod.val[1];
-    }
-    const Tx* input = x + input_offset;
-    Ty reduce_var = init;
+  if (reduce_lastdim) {
+    input_idx = blockIdx.y * blockDim.x + threadIdx.x;
+    left_idx = blockIdx.x;
+    stride = gridDim.y * blockDim.x;
+  } else {
+    input_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    left_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    stride = gridDim.y * blockDim.y;
+  }
+  // calculate the offset, means the addr where each thread really start.
+  int input_offset = left_index_calculator.Get(left_idx);
+  const Tx* input = x + input_offset;
+  Ty reduce_var = init;

+  // 1. reduce for each thread
+  if (left_idx < left_num) {
     // load REDUCE_VEC data once, and then compute
     Tx inputs[REDUCE_VEC];
     int bound = reduce_num - (REDUCE_VEC - 1) * stride;
     while (input_idx < bound) {
 #pragma unroll
       for (int i = 0; i < REDUCE_VEC; ++i) {
         int reduce_idx = input_idx + i * stride;
-        int idx_x = 0;
-#pragma unroll
-        for (int j = 0; j < ReduceRank; ++j) {
-          auto divmod = reduce_divmoders[j].Divmod(reduce_idx);
-          idx_x += (divmod.val[0] * x_strides[reduce_dim[j]]);
-          reduce_idx = divmod.val[1];
-        }
+        int idx_x = reduce_index_calculator.Get(reduce_idx);
         inputs[i] = input[idx_x];
       }
 #pragma unroll
@@ -635,13 +653,7 @@ __device__ void ReduceAny(
         break;
       }
       int reduce_idx = input_idx;
-      int idx_x = 0;
-#pragma unroll
-      for (int j = 0; j < ReduceRank; ++j) {
-        auto divmod = reduce_divmoders[j].Divmod(reduce_idx);
-        idx_x += (divmod.val[0] * x_strides[reduce_dim[j]]);
-        reduce_idx = divmod.val[1];
-      }
+      int idx_x = reduce_index_calculator.Get(reduce_idx);
       inputs[i] = input[idx_x];
       input_idx += stride;
     }
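The surrounding loop structure, unchanged by this patch apart from the index computation, is a width-`REDUCE_VEC` manual unroll: a main loop runs while a full group of strided loads is still in bounds, and the tail loop above (with the early `break`) picks up the remainder. A stand-alone host-side sketch of that pattern, assuming `REDUCE_VEC == 4` as defined at the top of the file:

```cpp
#include <cstdio>

constexpr int kVec = 4;  // stands in for REDUCE_VEC

// Main-loop + tail-loop pattern used by ReduceAny, specialized to summing:
// process kVec strided elements per iteration while all of them fit, then
// mop up at most kVec - 1 leftovers one element at a time.
float StridedSum(const float* data, int n, int start, int stride) {
  float acc = 0.f;
  int idx = start;
  int bound = n - (kVec - 1) * stride;  // last start where kVec loads fit
  while (idx < bound) {
    float v[kVec];
    for (int i = 0; i < kVec; ++i) v[i] = data[idx + i * stride];
    for (int i = 0; i < kVec; ++i) acc += v[i];
    idx += kVec * stride;
  }
  for (; idx < n; idx += stride) acc += data[idx];  // tail
  return acc;
}

int main() {
  float data[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  printf("%g\n", StridedSum(data, 10, 0, 1));  // prints 45
  return 0;
}
```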
@@ -654,47 +666,12 @@ __device__ void ReduceAny(
       reduce_var = reducer(reduce_var, transformer(inputs[i]));
       input_idx += stride;
     }
+  }

-    __syncthreads();
-
-    reduce_var = BlockReduce(reduce_var, reducer);
-
-    if (threadIdx.x == 0) {
-      y[blockIdx.x + blockIdx.y * gridDim.x] = reduce_var;
-    }
-  } else {
+  // 2. reduce in block y
+  if (blockDim.y > 1) {
     // need shared_memory to reduce block y
     __shared__ Ty shared_memory[detail::kMaxThread];
-    int left_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    int input_idx = (blockIdx.y * blockDim.y + threadIdx.y) * block_size;
-    int left_idx_tmp = left_idx;
-    int input_offset = 0;
-#pragma unroll
-    for (int i = 0; i < Rank - ReduceRank; ++i) {
-      auto divmod = left_divmoders[i].Divmod(left_idx_tmp);
-      input_offset += (divmod.val[0] * x_strides[left_dim[i]]);
-      left_idx_tmp = divmod.val[1];
-    }
-    const Tx* input = x + input_offset;
-    Ty reduce_var = init;
-
-    if (left_idx < left_num) {
-      int loop = reduce_num - input_idx;
-      loop = loop > block_size ? block_size : loop;
-      for (int i = 0; i < loop; ++i) {
-        int reduce_idx = i + input_idx;
-        int idx_x = 0;
-#pragma unroll
-        for (int j = 0; j < ReduceRank; ++j) {
-          auto divmod = reduce_divmoders[j].Divmod(reduce_idx);
-          idx_x += (divmod.val[0] * x_strides[reduce_dim[j]]);
-          reduce_idx = divmod.val[1];
-        }
-        reduce_var =
-            reducer(reduce_var, static_cast<Ty>(transformer(input[idx_x])));
-      }
-    }
-
     shared_memory[SharedMemoryIndex(0)] = reduce_var;
     for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
       __syncthreads();
@@ -704,7 +681,16 @@ __device__ void ReduceAny(
       }
       shared_memory[SharedMemoryIndex(0)] = reduce_var;
     }
+  }
+  __syncthreads();

+  if (reduce_lastdim) {
+    // 3. reduce in block x
+    reduce_var = BlockReduce(reduce_var, reducer);
+    if (threadIdx.x == 0) {
+      y[blockIdx.x + blockIdx.y * gridDim.x] = reduce_var;
+    }
+  } else {
     if (left_idx < left_num && threadIdx.y == 0) {
       y[blockIdx.y * left_num + left_idx] = reduce_var;
     }
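After the refactor the two former branches share one body with three stages: (1) every thread accumulates a private partial over a grid-stride range, (2) threads that share a `threadIdx.x` column combine their partials through shared memory when `blockDim.y > 1`, and (3) the last-dim path finishes with a `BlockReduce` across `threadIdx.x`, while the other path writes one value per column. Below is a self-contained CUDA sketch of that staging for a plain float sum, with a contiguous layout and a serial shared-memory pass standing in for `BlockReduce`; it assumes (as the launch configuration is expected to arrange) that `blockDim.y == 1` whenever `reduce_lastdim` is set and that the block has at most 1024 threads:

```cuda
// Not Paddle code: a minimal staged-reduction kernel mirroring the
// per-thread / block-y / block-x structure of ReduceAny above.
__global__ void StagedSum(const float* x, float* y, int reduce_num,
                          int left_num, bool reduce_lastdim) {
  // Stage 1: grid-stride accumulation into a per-thread partial.
  int input_idx = reduce_lastdim ? blockIdx.y * blockDim.x + threadIdx.x
                                 : blockIdx.y * blockDim.y + threadIdx.y;
  int left_idx = reduce_lastdim ? blockIdx.x
                                : blockIdx.x * blockDim.x + threadIdx.x;
  int stride =
      reduce_lastdim ? gridDim.y * blockDim.x : gridDim.y * blockDim.y;
  float partial = 0.f;
  if (left_idx < left_num) {
    for (int i = input_idx; i < reduce_num; i += stride) {
      partial += x[left_idx * reduce_num + i];  // contiguous layout assumed
    }
  }

  // Stage 2: tree-reduce the partials held by threads sharing a threadIdx.x.
  __shared__ float smem[1024];
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  smem[tid] = partial;
  for (int s = blockDim.y / 2; s > 0; s >>= 1) {
    __syncthreads();
    if (threadIdx.y < s) {
      partial += smem[tid + s * blockDim.x];
      smem[tid] = partial;
    }
  }
  __syncthreads();

  // Stage 3: one output per (blockIdx.x, blockIdx.y); partials along
  // gridDim.y still need the follow-up kernel, as in LaunchReduceKernel.
  if (reduce_lastdim) {
    // The real kernel uses BlockReduce across threadIdx.x here; a serial
    // pass over the blockDim.x partials plays that role in this sketch.
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      float block_sum = 0.f;
      for (int i = 0; i < blockDim.x; ++i) block_sum += smem[i];
      y[blockIdx.x + blockIdx.y * gridDim.x] = block_sum;
    }
  } else if (threadIdx.y == 0 && left_idx < left_num) {
    y[blockIdx.y * left_num + left_idx] = partial;
  }
}
```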
@@ -713,15 +699,13 @@ __device__ void ReduceAny(

 // module function designed for global function
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int Rank, int ReduceRank, int ReduceType>
-__device__ void ReduceModule(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
-    int reduce_num, int left_num, int blocking_size,
-    paddle::framework::Array<int, Rank> x_strides,
-    paddle::framework::Array<int, ReduceRank> reduce_dim,
-    paddle::framework::Array<FastDivMod, ReduceRank> reduce_divmoders,
-    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
-    paddle::framework::Array<FastDivMod, Rank - ReduceRank> left_divmoders) {
+          int ReduceType>
+__device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
+                             TransformOp transformer, Ty init, int reduce_num,
+                             int left_num, int blocking_size,
+                             bool reduce_lastdim,
+                             const IndexCalculator& reduce_index_calculator,
+                             const IndexCalculator& left_index_calculator) {
   if (ReduceType == ReduceType::kReduceLastDim) {
     ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
                                                  init, reduce_num);
@@ -733,48 +717,52 @@ __device__ void ReduceModule(

     // reduce_rank >= 2
   } else {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank>(
+    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
         x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
-        x_strides, reduce_dim, reduce_divmoders, left_dim, left_divmoders);
+        reduce_lastdim, reduce_index_calculator, left_index_calculator);
   }
 }

 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int Rank, int ReduceRank, int ReduceType>
-__global__ void ReduceKernelFunction(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
-    int reduce_num, int left_num, int block_size,
-    paddle::framework::Array<int, Rank> x_strides,
-    paddle::framework::Array<int, ReduceRank> reduce_dim,
-    paddle::framework::Array<FastDivMod, ReduceRank> reduce_divmoders,
-    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
-    paddle::framework::Array<FastDivMod, Rank - ReduceRank> left_divmoders) {
-  ReduceModule<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank, ReduceType>(
-      x, y, reducer, transformer, init, reduce_num, left_num, block_size,
-      x_strides, reduce_dim, reduce_divmoders, left_dim, left_divmoders);
+          int ReduceType>
+__global__ void ReduceKernelFunction(const Tx* x, Ty* y, ReduceOp reducer,
+                                     TransformOp transformer, Ty init,
+                                     int reduce_num, int left_num,
+                                     int blocking_size, bool reduce_lastdim,
+                                     IndexCalculator reduce_index_calculator,
+                                     IndexCalculator left_index_calculator) {
+  ReduceModule<Tx, Ty, ReduceOp, TransformOp, ReduceType>(
+      x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
+      reduce_lastdim, reduce_index_calculator, left_index_calculator);
 }

-template <typename Tx, typename Ty, typename ReduceOp, int Rank, int ReduceRank>
+template <typename Tx, typename Ty, typename ReduceOp>
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                                const ReduceOp& reducer, Ty init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   using TransformOp = typename ReduceOp::Transformer;

+  int reduce_rank = config.reduce_strides.size();
+  int left_rank = config.left_strides.size();
+  auto reduce_index_calculator = IndexCalculator(
+      reduce_rank, detail::VectorToArray<int, MAX_DIM>(config.x_strides),
+      detail::VectorToArray<int, MAX_DIM>(config.reduce_dim),
+      detail::VectorToArray<FastDivMod, MAX_DIM>(config.reduce_divmoders));
+  auto left_index_calculator = IndexCalculator(
+      left_rank, detail::VectorToArray<int, MAX_DIM>(config.x_strides),
+      detail::VectorToArray<int, MAX_DIM>(config.left_dim),
+      detail::VectorToArray<FastDivMod, MAX_DIM>(config.left_divmoders));
+
 #define CUB_REDUCE_TYPE_CASE(type)                                           \
   case type: {                                                               \
     constexpr auto kReduceType = type;                                       \
     ReduceKernelFunction<                                                    \
-        Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank,                     \
+        Tx, Ty, ReduceOp, TransformOp,                                       \
         kReduceType><<<config.grid, config.block, 0, stream>>>(              \
         x_data, config.output_data, reducer, TransformOp(config.reduce_num), \
         init, config.reduce_num, config.left_num, config.blocking_size,      \
-        detail::VectorToArray<int, Rank>(config.x_strides),                  \
-        detail::VectorToArray<int, ReduceRank>(config.reduce_dim),           \
-        detail::VectorToArray<FastDivMod, ReduceRank>(                       \
-            config.reduce_divmoders),                                        \
-        detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),      \
-        detail::VectorToArray<FastDivMod, Rank - ReduceRank>(                \
-            config.left_divmoders));                                         \
+        config.reduce_lastdim, reduce_index_calculator,                      \
+        left_index_calculator);                                              \
   } break

   switch (config.reduce_type) {
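Since the rank now travels inside the `IndexCalculator`s, the host side only dispatches on `config.reduce_type`. For readability, this is roughly what `CUB_REDUCE_TYPE_CASE(ReduceType::kReduceAny)` expands to inside the `switch` above (a plain expansion of the macro, not additional code in the patch):

```cpp
case ReduceType::kReduceAny: {
  constexpr auto kReduceType = ReduceType::kReduceAny;
  ReduceKernelFunction<Tx, Ty, ReduceOp, TransformOp, kReduceType>
      <<<config.grid, config.block, 0, stream>>>(
          x_data, config.output_data, reducer, TransformOp(config.reduce_num),
          init, config.reduce_num, config.left_num, config.blocking_size,
          config.reduce_lastdim, reduce_index_calculator,
          left_index_calculator);
} break;
```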
@@ -795,61 +783,13 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
   }

     ReduceKernelFunction<
-        Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, Rank, ReduceRank,
+        Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>,
         ReduceType::kReduceHigherDim><<<grid, block, 0, stream>>>(
         config.output_data, y_data, reducer,
         detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
-        config.left_num, config.grid.y,
-        detail::VectorToArray<int, Rank>(config.x_strides),
-        detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
-        detail::VectorToArray<FastDivMod, ReduceRank>(config.reduce_divmoders),
-        detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),
-        detail::VectorToArray<FastDivMod, Rank - ReduceRank>(
-            config.left_divmoders));
-  }
-}
-
-template <typename Tx, typename Ty, typename ReduceOp>
-static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
-                             const ReduceOp& reducer, Ty init,
-                             gpuStream_t stream, ReduceConfig<Ty> config) {
-  int reduce_rank = config.reduce_strides.size();
-  int rank = config.x_strides.size();
-
-#define CUB_RANK_CASE(i, ...)                 \
-  case i: {                                   \
-    constexpr auto Rank = i;                  \
-    switch (reduce_rank) { __VA_ARGS__; }     \
-  } break
-
-#define CUB_REDUCE_RANK_CASE(i, ...)                         \
-  case i: {                                                  \
-    constexpr auto ReduceRank = i;                           \
-    LaunchReduceKernel<Tx, Ty, ReduceOp, Rank, ReduceRank>(  \
-        x_data, y_data, reducer, init, stream, config);      \
-  } break
-
-  detail::CheckReduceRank(reduce_rank, rank);
-  switch (rank) {
-    CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););
-
-    CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););
-
-    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););
-
-    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););
-
-    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););
-
-    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););
-
-    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););
-
-    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););
+        config.left_num, config.grid.y, config.reduce_lastdim,
+        reduce_index_calculator, left_index_calculator);
   }
-
-#undef CUB_REDUCE_RANK_CASE
-#undef CUB_RANK_CASE
 }

 template <typename Tx, typename Ty,
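For comparison, the deleted `ReduceKernelImpl` used to turn the runtime `(rank, reduce_rank)` pair into compile-time template arguments through the two macros above; the rank-3 case, for instance, expanded to roughly the switch below, and each such case forced another instantiation of the whole kernel. The runtime `IndexCalculator` removes that combinatorial dispatch (expansion shown for illustration only, against the old templated `LaunchReduceKernel`):

```cpp
case 3: {
  constexpr auto Rank = 3;
  switch (reduce_rank) {
    case 1: {
      constexpr auto ReduceRank = 1;
      LaunchReduceKernel<Tx, Ty, ReduceOp, Rank, ReduceRank>(
          x_data, y_data, reducer, init, stream, config);
    } break;
    case 2: {
      constexpr auto ReduceRank = 2;
      LaunchReduceKernel<Tx, Ty, ReduceOp, Rank, ReduceRank>(
          x_data, y_data, reducer, init, stream, config);
    } break;
  }
} break;
```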
@@ -899,8 +839,8 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     return;
   }

-  ReduceKernelImpl<Tx, Ty, ReduceOp<Tx, Ty>>(x_data, y_data, reducer,
-                                             reducer.initial(), stream, config);
+  LaunchReduceKernel<Tx, Ty, ReduceOp<Tx, Ty>>(
+      x_data, y_data, reducer, reducer.initial(), stream, config);
 }

 template <typename Tx, template <typename, typename> class ReduceOp>
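`LaunchReduceKernel` only relies on the reducer exposing a nested `Transformer` type (constructed with `reduce_num`), an `initial()` value, and a device-callable `operator()` that combines two partials, as the calls above show. A hypothetical reducer written against that inferred interface (the names and exact functor signature are assumptions for illustration, not part of the patch):

```cpp
// Illustrative only: a sum-style reducer shaped the way LaunchReduceKernel
// expects, inferred from how `reducer` and `TransformOp` are used above.
template <typename Tx, typename Ty>
struct CustomSum {
  // Per-element transform; IdentityFunctor is constructed with an int, as in
  // the second-pass launch above (detail::IdentityFunctor<Ty>(config.grid.y)).
  using Transformer = detail::IdentityFunctor<Ty>;

  Ty initial() const { return static_cast<Ty>(0); }

  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
    return a + b;
  }
};

// It would then follow the same path as the built-in ops, e.g.
//   TensorReduceFunctorImpl<float, float, CustomSum>(x, &y, ...);
// which ends in LaunchReduceKernel<float, float, CustomSum<float, float>>.
```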