@@ -33,6 +33,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/fast_divmod.h"
 
@@ -145,7 +146,6 @@ using Tensor = framework::Tensor;
 constexpr int kMaxRank = framework::DDim::kMaxRank;
 
 enum ReduceType {
-  kReduceAll = 0x00,        // when reduce_rank == x_rank
   kReduceLastDim = 0x01,    // when reduce_dim[0] == x_dim.size() - 1;
   kReduceHigherDim = 0x02,  // ReduceFirstDim or reduceSecondDim
   kReduceAny = 0x03,        // when reduce_dim.size() > 1
@@ -341,9 +341,8 @@ struct ReduceConfig {
     bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
                            (left_num > REDUCE_SPLIT_BOUNDARY);
 
-    if (rank == reduce_rank) {
-      reduce_type = static_cast<int>(ReduceType::kReduceAll);
-    } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
+    if (rank == reduce_rank ||
+        (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1)) {
       reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
     } else if (reduce_rank == 1 &&
                ((rank == 2 && is_large_enough) || rank != 2)) {
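For context, a sketch of how the merged branch classifies a few concrete
shapes (the shapes are illustrative, not taken from the patch; the final
else branch is presumably kReduceAny, per the enum comments above):

// Illustrative classification under the rewritten condition:
// x_dim = {8, 16}, reduce_dim = {0, 1}: rank == reduce_rank, so the former
//   kReduceAll case now falls into kReduceLastDim.
// x_dim = {8, 16}, reduce_dim = {1}:    rank == 2, reduce_rank == 1,
//   reduce_dim[0] == 1, so kReduceLastDim.
// x_dim = {8, 16}, reduce_dim = {0}:    reduce_rank == 1, rank == 2, so
//   kReduceHigherDim when is_large_enough, otherwise the else branch.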
@@ -576,36 +575,37 @@ static __device__ T BlockYReduce(T val, ReduceOp reducer) {
 // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
 // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
 // else grid.z = 1, grid.y = ny / block_size, grid.x = nx / 32
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
-                                TransformOp transformer, Ty init,
+                                TransformOp transformer, MPType init,
                                 int reduce_num, int left_num, int block_size) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   int idy = blockIdx.y * block_size;
 
-  Ty reduce_var = init;
+  MPType reduce_var = init;
 
   if (idx < left_num) {
     int loop = reduce_num - idy;
     loop = loop > block_size ? block_size : loop;
 
     for (int iy = 0; iy < loop; iy++) {
       int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num;
-      reduce_var = reducer(reduce_var, static_cast<Ty>(transformer(x[id])));
+      reduce_var = reducer(reduce_var, static_cast<MPType>(transformer(x[id])));
     }
 
     y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] =
-        reduce_var;
+        static_cast<Ty>(reduce_var);
   }
 }
 
 // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
 // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
 // function will be used
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          typename ReduceIndexCal, typename LeftIndexCal>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp, typename ReduceIndexCal, typename LeftIndexCal>
 __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
-                          TransformOp transformer, Ty init, int reduce_num,
+                          TransformOp transformer, MPType init, int reduce_num,
                           int left_num, bool reduce_lastdim,
                           ReduceIndexCal reduce_index_calculator,
                           LeftIndexCal left_index_calculator) {
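A worked instance of the grid layout described in the comment above
ReduceHigherDim, with hypothetical sizes (not taken from the patch):

// x_dim = {nz, ny, nx} = {4, 1024, 64}, block_size = 128, reducing axis 1:
//   grid.z = nz = 4, grid.y = ny / block_size = 8, grid.x = nx / 32 = 2.
// Each block partially reduces a 128-element segment of the reduce axis
// per (z, x) slice; the partials it writes are folded by the second
// kernel launch in LaunchReduceKernel further down.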
@@ -623,7 +623,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
   // calculate the offset, means the addr where each thread really start.
   int input_offset = left_index_calculator(left_idx);
   const Tx* input = x + input_offset;
-  Ty reduce_var = init;
+  MPType reduce_var = init;
 
   // 1. reduce for each thread
   if (left_idx < left_num) {
@@ -639,7 +639,8 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
     }
 #pragma unroll
     for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
-      reduce_var = reducer(reduce_var, transformer(input_reg[i]));
+      reduce_var =
+          reducer(reduce_var, static_cast<MPType>(transformer(input_reg[i])));
     }
     input_idx += REDUCE_VEC_SIZE * stride;
   }
@@ -662,7 +663,8 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
       if (input_idx >= reduce_num) {
        break;
       }
-      reduce_var = reducer(reduce_var, transformer(input_reg[i]));
+      reduce_var =
+          reducer(reduce_var, static_cast<MPType>(transformer(input_reg[i])));
       input_idx += stride;
     }
   }
@@ -677,71 +679,72 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
     // 3. reduce in block x
     reduce_var = BlockXReduce(reduce_var, reducer);
     if (left_idx < left_num && threadIdx.x == 0) {
-      y[blockIdx.y * left_num + left_idx] = reduce_var;
+      y[blockIdx.y * left_num + left_idx] = static_cast<Ty>(reduce_var);
     }
   } else {
     if (left_idx < left_num && threadIdx.y == 0) {
-      y[blockIdx.y * left_num + left_idx] = reduce_var;
+      y[blockIdx.y * left_num + left_idx] = static_cast<Ty>(reduce_var);
     }
   }
 }
 
 // module function designed for global function
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
-                             TransformOp transformer, Ty init, int reduce_num,
-                             int left_num, int blocking_size, int reduce_type,
-                             bool reduce_lastdim,
+                             TransformOp transformer, MPType init,
+                             int reduce_num, int left_num, int blocking_size,
+                             int reduce_type, bool reduce_lastdim,
                              const IndexCalculator& reduce_index_calculator,
                              const IndexCalculator& left_index_calculator) {
   if (reduce_type == ReduceType::kReduceLastDim) {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
+    ReduceAny<Tx, Ty, MPType, ReduceOp, TransformOp>(
         x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
         [&](int idx) { return idx; },
        [&](int idx) { return idx * reduce_num; });
 
     // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
   } else if (reduce_type == ReduceType::kReduceHigherDim) {
-    ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
+    ReduceHigherDim<Tx, Ty, MPType, ReduceOp, TransformOp>(
        x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
 
    // reduce_rank >= 2
   } else {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
+    ReduceAny<Tx, Ty, MPType, ReduceOp, TransformOp>(
        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
        [&](int idx) { return reduce_index_calculator.Get(idx); },
        [&](int idx) { return left_index_calculator.Get(idx); });
   }
 }
 
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __global__ void ReduceKernelFunction(const Tx* x, Ty* y, ReduceOp reducer,
-                                     TransformOp transformer, Ty init,
+                                     TransformOp transformer, MPType init,
                                      int reduce_num, int left_num,
                                      int blocking_size, int reduce_type,
                                      bool reduce_lastdim,
                                      IndexCalculator reduce_index_calculator,
                                      IndexCalculator left_index_calculator) {
-  ReduceModule<Tx, Ty, ReduceOp, TransformOp>(
+  ReduceModule<Tx, Ty, MPType, ReduceOp, TransformOp>(
      x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
      reduce_type, reduce_lastdim, reduce_index_calculator,
      left_index_calculator);
 }
 
-template <typename Tx, typename Ty, typename ReduceOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp>
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
-                               const ReduceOp& reducer, Ty init,
+                               const ReduceOp& reducer, MPType init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   using TransformOp = typename ReduceOp::Transformer;
-
   int reduce_rank = config.reduce_strides.size();
   int left_rank = config.left_strides.size();
   auto reduce_index_calculator = IndexCalculator(
      reduce_rank, config.reduce_dim, config.reduce_strides, config.x_strides);
   auto left_index_calculator = IndexCalculator(
      left_rank, config.left_dim, config.left_strides, config.x_strides);
 
-  ReduceKernelFunction<Tx, Ty, ReduceOp,
+  ReduceKernelFunction<Tx, Ty, MPType, ReduceOp,
                        TransformOp><<<config.grid, config.block, 0, stream>>>(
      x_data, config.output_data, reducer, TransformOp(config.reduce_num), init,
      config.reduce_num, config.left_num, config.blocking_size,
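The reason MPType is threaded through every signature above is that fp16
inputs should be accumulated in a wider type and narrowed back to Ty only
on the final store. A standalone sketch of that pattern (illustrative, not
code from the patch):

// Accumulate in MPType, store in Ty. With Tx = Ty = float16 and
// MPType = float, the loop avoids intermediate fp16 rounding.
template <typename Tx, typename Ty, typename MPType, typename ReduceOp>
__device__ void AccumulateSketch(const Tx* in, Ty* out, int n,
                                 ReduceOp reducer, MPType init) {
  MPType acc = init;
  for (int i = 0; i < n; ++i) {
    acc = reducer(acc, static_cast<MPType>(in[i]));
  }
  *out = static_cast<Ty>(acc);  // narrow once, at the end
}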
@@ -759,10 +762,11 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
     grid = dim3(config.grid.x, 1, config.grid.z);
   }
 
-  ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<
-      Ty>><<<grid, block, 0, stream>>>(
+  ReduceKernelFunction<
+      Ty, Ty, MPType, ReduceOp,
+      detail::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
      config.output_data, y_data, reducer,
-      detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
+      detail::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
      config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
      config.reduce_lastdim, reduce_index_calculator, left_index_calculator);
 }
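This second launch re-reduces the Ty partials written by the first pass,
widening each one back to MPType through the now two-parameter
IdentityFunctor. A guess at the shape such a functor needs (the real
definition lives in the detail namespace and may differ):

// Assumed two-parameter form of detail::IdentityFunctor: reads Tx,
// casts to Ty. The int constructor argument mirrors TransformOp's
// constructor and is unused for an identity transform.
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
  HOSTDEVICE explicit inline IdentityFunctor(int n) {}
  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x);
  }
};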
@@ -793,11 +797,12 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   }
 
   config.SetOutputData(y_data, x.place(), &tmp);
-
-  using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
-  auto reducer = ReduceOp<Tx, Ty>();
-  // launch CUB::Reduce
-  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
+  bool use_cub_reduce = (config.left_num == 1) &&
+                        (!std::is_same<Tx, paddle::platform::float16>::value);
+  if (use_cub_reduce) {
+    // launch CUB::Reduce
+    using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
+    auto reducer = ReduceOp<Tx, Ty>();
     cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
        x_data, TransformOp(config.reduce_num));
     size_t temp_storage_bytes = 0;
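The CUB branch follows the standard two-phase cub::DeviceReduce protocol:
a first call with a null workspace pointer only computes
temp_storage_bytes, then the same call is repeated with the allocated
workspace. A minimal self-contained sketch (float sum as a stand-in for
the templated ReduceOp, raw cudaMalloc instead of Paddle's allocator):

#include <cub/cub.cuh>

void CubReduceSketch(const float* d_in, float* d_out, int n,
                     cudaStream_t stream) {
  size_t temp_storage_bytes = 0;
  // Pass 1: null workspace, only sizes the temporary storage.
  cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, d_in, d_out, n, stream);
  void* d_temp_storage = nullptr;
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  // Pass 2: identical arguments, performs the actual reduction.
  cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, n,
                         stream);
  cudaFree(d_temp_storage);
}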
@@ -815,7 +820,9 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     return;
   }
 
-  LaunchReduceKernel<Tx, Ty, ReduceOp<Tx, Ty>>(
+  using MPType = typename details::MPTypeTrait<Ty>::Type;
+  auto reducer = ReduceOp<Tx, MPType>();
+  LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>>(
      x_data, y_data, reducer, reducer.initial(), stream, config);
 }
 
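MPTypeTrait comes from the newly included fp16_type_traits.h. Its expected
mapping, sketched under the assumption that only float16 gets a wider
accumulator type:

// Assumed behavior of details::MPTypeTrait: float16 accumulates in
// float; every other type accumulates in itself.
template <typename T>
struct MPTypeTrait {
  using Type = T;
};

template <>
struct MPTypeTrait<paddle::platform::float16> {
  using Type = float;
};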