@@ -87,8 +87,10 @@ static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
8787
// Per-backend launch constants.
// HIP (AMD) devices execute 64-lane wavefronts, so the warp width there is
// 64 and larger blocks are allowed; CUDA devices use 32-lane warps.
#ifdef __HIPCC__
constexpr int kMaxThread = 256;
constexpr int warp_size = 64;
#else
constexpr int kMaxThread = 128;
constexpr int warp_size = 32;
#endif
9395
9496// get blockDim for reduceLastDim and reduceAny
@@ -393,63 +395,31 @@ struct ReduceConfig {
393395 dim3 grid;
394396};
395397
// Warp-wide reduction via shuffle-down.
// Each lane starts with its own `val`; after log2(warp_size) halving steps
// lane 0 holds `reducer` folded over all lanes' inputs (other lanes hold
// partial results). NOTE(review): assumes the full warp participates in the
// shuffle mask produced by CREATE_SHFL_MASK — confirm callers never invoke
// this from a partially-active warp.
template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  // Butterfly-style halving: combine with the lane `offset` positions above.
  int offset = detail::warp_size >> 1;
  while (offset > 0) {
    T neighbor = paddle::platform::CudaShuffleDownSync(mask, val, offset);
    val = reducer(val, neighbor);
    offset >>= 1;
  }
  return val;
}
408+
438409template <typename T, typename ReduceOp>
439410__device__ __forceinline__ T BlockReduce (T val, ReduceOp reducer) {
440- __shared__ T shared[detail::kMaxThread ];
411+ __shared__ T shared[detail::warp_size ];
441412 int block_dim_x = blockDim.x ;
442413 if (blockDim.x > warpSize) {
443- block_dim_x = warpSize;
444- shared[threadIdx.x ] = val;
445- for (int stride = blockDim.x / 2 ; stride >= warpSize; stride >>= 1 ) {
446- __syncthreads ();
447- if (threadIdx.x < stride) {
448- T temp = shared[threadIdx.x + stride];
449- val = reducer (val, temp);
450- shared[threadIdx.x ] = val;
451- }
414+ block_dim_x = blockDim.x / warpSize;
415+ int lane = threadIdx.x % warpSize;
416+ int wid = threadIdx.x / warpSize;
417+ val = WarpReduce (val, reducer);
418+ if (lane == 0 ) {
419+ shared[wid] = val;
452420 }
421+ __syncthreads ();
422+ val = shared[lane];
453423 }
454424 __syncthreads ();
455425
@@ -616,37 +586,36 @@ __global__ void ReduceKernelFunction(
616586 left_strides);
617587}
618588
// Host-side launcher for ReduceKernelFunction, specialized for a concrete
// (Rank, ReduceRank) pair chosen by the dispatch macros in ReduceKernelImpl.
//
// First pass reduces `x_data` according to `config`; when the configuration
// leaves per-block partial results (`config.should_reduce_again`), a second
// pass folds those partials along grid.y with an identity transform and
// writes the final values into `y_data`.
// NOTE(review): when should_reduce_again is false the first pass writes to
// config.output_data — presumably that aliases y_data in that case; confirm
// against ReduceConfig.
template <typename Tx, typename Ty, typename ReduceOp, int Rank, int ReduceRank>
static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                               const ReduceOp& reducer, Ty init,
                               gpuStream_t stream, ReduceConfig<Ty> config) {
  using TransformOp = typename ReduceOp::Transformer;

  // Rank of the non-reduced ("left") dimensions, used for both launches.
  constexpr int kLeftRank = Rank - ReduceRank;

  // Pass 1: reduce the input, transforming each element with TransformOp.
  ReduceKernelFunction<Tx, Ty, ReduceOp, TransformOp, Rank,
                       ReduceRank><<<config.grid, config.block, 0, stream>>>(
      x_data, config.output_data, reducer, TransformOp(config.reduce_num), init,
      config.reduce_num, config.left_num, config.blocking_size,
      config.reduce_type, detail::VectorToArray<int, Rank>(config.x_strides),
      detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
      detail::VectorToArray<int, ReduceRank>(config.reduce_strides),
      detail::VectorToArray<int, kLeftRank>(config.left_dim),
      detail::VectorToArray<int, kLeftRank>(config.left_strides));

  if (config.should_reduce_again) {
    // Pass 2: collapse the grid.y partial results per output element,
    // treated as a higher-dimension reduce over the intermediate buffer.
    dim3 block(config.block.x, 1, 1);
    dim3 grid(config.grid.x, 1, config.grid.z);

    ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, Rank,
                         ReduceRank><<<grid, block, 0, stream>>>(
        config.output_data, y_data, reducer,
        detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
        config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
        detail::VectorToArray<int, Rank>(config.x_strides),
        detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
        detail::VectorToArray<int, ReduceRank>(config.reduce_strides),
        detail::VectorToArray<int, kLeftRank>(config.left_dim),
        detail::VectorToArray<int, kLeftRank>(config.left_strides));
  }
}
652621
@@ -659,15 +628,15 @@ static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
659628
660629#define CUB_RANK_CASE (i, ...) \
661630 case i: { \
662- constexpr auto kRank = i; \
631+ constexpr auto Rank = i; \
663632 switch (reduce_rank) { __VA_ARGS__; } \
664633 } break
665634
666- #define CUB_REDUCE_RANK_CASE (i, ...) \
667- case i: { \
668- constexpr auto kReduceRank = i; \
669- LaunchReduceKernel<Tx, Ty, ReduceOp, kRank , kReduceRank >( \
670- x_data, y_data, reducer, init, stream, config); \
635+ #define CUB_REDUCE_RANK_CASE (i, ...) \
636+ case i: { \
637+ constexpr auto ReduceRank = i; \
638+ LaunchReduceKernel<Tx, Ty, ReduceOp, Rank, ReduceRank >( \
639+ x_data, y_data, reducer, init, stream, config); \
671640 } break
672641
673642 detail::CheckReduceRank (reduce_rank, rank);
0 commit comments