
Commit 3dd8c9d

merge
1 parent 3a54149 commit 3dd8c9d

File tree: 3 files changed, +66 -160 lines


paddle/fluid/operators/reduce_ops/reduce_mean_op.cu

Lines changed: 6 additions & 53 deletions
@@ -13,58 +13,11 @@
 // limitations under the License.

 #include <vector>
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"

-namespace paddle {
-namespace operators {
-
-template <typename T>
-struct DivideFunctor {
-  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
-
-  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
-
- private:
-  T n_inv;
-};
-
-template <typename T>
-class ReduceMeanKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    bool reduce_all = context.Attr<bool>("reduce_all");
-    auto* input = context.Input<Tensor>("X");
-    auto* output = context.Output<Tensor>("Out");
-
-    auto dims = context.Attr<std::vector<int>>("dim");
-    bool keep_dim = context.Attr<bool>("keep_dim");
-
-    std::vector<int> reduce_dims;
-    if (reduce_all) {
-      reduce_dims.resize(input->dims().size());
-      for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
-    } else {
-      for (auto e : dims) {
-        reduce_dims.push_back(e >= 0 ? e : e + input->dims().size());
-      }
-    }
-
-    int reduce_num = 1;
-    for (int i = 0; i < reduce_dims.size(); ++i) {
-      reduce_num *= input->dims()[reduce_dims[i]];
-    }
-
-    auto stream = context.cuda_device_context().stream();
-    TensorReduce<T, T, cub::Sum, DivideFunctor<T>>(
-        *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
-        DivideFunctor<T>(reduce_num), stream);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<bool>,
-                        ops::ReduceMeanKernel<float>,
-                        ops::ReduceMeanKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    reduce_mean, ops::ReduceCudaKernel<bool, paddle::operators::CustomMean>,
+    ops::ReduceCudaKernel<float, paddle::operators::CustomMean>,
+    ops::ReduceCudaKernel<double, paddle::operators::CustomMean>);
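
Note: the reduce_mean CUDA kernel is no longer hand-written; the file now only registers the shared ops::ReduceCudaKernel with a CustomMean functor from reduce_functor_op.h, which is not shown in this commit. The sketch below is only an illustration of how such a functor can reproduce the deleted DivideFunctor + cub::Sum behaviour through the interface reduce_op.cu.h expects (a nested Transformer built from reduce_num, an initial() value, and a binary operator()). The names and the host-only code are assumptions for illustration, not Paddle's actual implementation.

#include <vector>

// Illustrative only: MeanFunctorSketch is NOT Paddle's CustomMean; it just
// mirrors the removed DivideFunctor + cub::Sum path on the host.
template <typename T>
struct MeanFunctorSketch {
  // Transformer scales every element by 1/reduce_num, so a plain sum of the
  // transformed values is the mean (same trick as the deleted DivideFunctor).
  struct Transformer {
    explicit Transformer(int n) : n_inv(static_cast<T>(1.0 / n)) {}
    T operator()(const T& x) const { return x * n_inv; }
    T n_inv;
  };
  T initial() const { return static_cast<T>(0); }
  T operator()(const T& a, const T& b) const { return a + b; }
};

// Host-side usage: mean of a vector via transform-then-sum.
template <typename T>
T MeanViaTransformSum(const std::vector<T>& x) {
  MeanFunctorSketch<T> reducer;
  typename MeanFunctorSketch<T>::Transformer transform(static_cast<int>(x.size()));
  T reduce_var = reducer.initial();
  for (const T& v : x) reduce_var = reducer(reduce_var, transform(v));
  return reduce_var;
}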

paddle/fluid/operators/reduce_ops/reduce_op.cu.h

Lines changed: 47 additions & 40 deletions
@@ -33,6 +33,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/fast_divmod.h"

@@ -145,7 +146,6 @@ using Tensor = framework::Tensor;
 constexpr int kMaxRank = framework::DDim::kMaxRank;

 enum ReduceType {
-  kReduceAll = 0x00,        // when reduce_rank == x_rank
   kReduceLastDim = 0x01,    // when reduce_dim[0] == x_dim.size() - 1;
   kReduceHigherDim = 0x02,  // ReduceFirstDim or reduceSecondDim
   kReduceAny = 0x03,        // when reduce_dim.size() > 1
@@ -341,9 +341,8 @@ struct ReduceConfig {
     bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
                            (left_num > REDUCE_SPLIT_BOUNDARY);

-    if (rank == reduce_rank) {
-      reduce_type = static_cast<int>(ReduceType::kReduceAll);
-    } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
+    if (rank == reduce_rank ||
+        rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
       reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
     } else if (reduce_rank == 1 &&
               ((rank == 2 && is_large_enough) || rank != 2)) {
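
Note: with kReduceAll removed, a full reduction (rank == reduce_rank) is now routed to the kReduceLastDim path. The merged condition relies on && binding tighter than ||; the standalone restatement below adds parentheses only for readability and is not part of the commit.

#include <vector>

// Equivalent to the merged condition in ReduceConfig above; illustrative only.
bool TakesLastDimPath(int rank, int reduce_rank, const std::vector<int>& reduce_dim) {
  return rank == reduce_rank ||
         (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1);
}
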
@@ -576,36 +575,37 @@ static __device__ T BlockYReduce(T val, ReduceOp reducer) {
 // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
 // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
 // else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
-                                TransformOp transformer, Ty init,
+                                TransformOp transformer, MPType init,
                                 int reduce_num, int left_num, int block_size) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   int idy = blockIdx.y * block_size;

-  Ty reduce_var = init;
+  MPType reduce_var = init;

   if (idx < left_num) {
     int loop = reduce_num - idy;
     loop = loop > block_size ? block_size : loop;

     for (int iy = 0; iy < loop; iy++) {
       int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num;
-      reduce_var = reducer(reduce_var, static_cast<Ty>(transformer(x[id])));
+      reduce_var = reducer(reduce_var, static_cast<MPType>(transformer(x[id])));
     }

     y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] =
-        reduce_var;
+        static_cast<Ty>(reduce_var);
   }
 }

 // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
 // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
 // function will be used
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          typename ReduceIndexCal, typename LeftIndexCal>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp, typename ReduceIndexCal, typename LeftIndexCal>
 __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
-                          TransformOp transformer, Ty init, int reduce_num,
+                          TransformOp transformer, MPType init, int reduce_num,
                           int left_num, bool reduce_lastdim,
                           ReduceIndexCal reduce_index_calculator,
                           LeftIndexCal left_index_calculator) {
@@ -623,7 +623,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
   // calculate the offset, means the addr where each thread really start.
   int input_offset = left_index_calculator(left_idx);
   const Tx* input = x + input_offset;
-  Ty reduce_var = init;
+  MPType reduce_var = init;

   // 1. reduce for each thread
   if (left_idx < left_num) {
@@ -639,7 +639,8 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
     }
 #pragma unroll
     for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
-      reduce_var = reducer(reduce_var, transformer(input_reg[i]));
+      reduce_var =
+          reducer(reduce_var, static_cast<MPType>(transformer(input_reg[i])));
     }
     input_idx += REDUCE_VEC_SIZE * stride;
   }
@@ -662,7 +663,8 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
       if (input_idx >= reduce_num) {
         break;
       }
-      reduce_var = reducer(reduce_var, transformer(input_reg[i]));
+      reduce_var =
+          reducer(reduce_var, static_cast<MPType>(transformer(input_reg[i])));
       input_idx += stride;
     }
   }
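
Note: the new MPType template parameter changes only where the running value lives. reduce_var is declared and accumulated in MPType and narrowed back to Ty once, at the final store. Below is a host-only sketch of that pattern; MPTypeTraitStub and float16_stub are stand-ins (the assumption is that details::MPTypeTrait from fp16_type_traits.h maps platform::float16 to float and is the identity otherwise), so this is an illustration, not Paddle code.

#include <iostream>
#include <vector>

struct float16_stub {  // placeholder, not paddle::platform::float16
  float v = 0.f;
  float16_stub() = default;
  explicit float16_stub(float f) : v(f) {}
  explicit operator float() const { return v; }
};

template <typename T> struct MPTypeTraitStub { using Type = T; };
template <> struct MPTypeTraitStub<float16_stub> { using Type = float; };

// Read Tx, accumulate in MPType, narrow to Ty exactly once at the end --
// the same shape as the updated ReduceHigherDim / ReduceAny inner loops.
template <typename Tx, typename Ty>
Ty SumWithMPType(const std::vector<Tx>& x) {
  using MPType = typename MPTypeTraitStub<Ty>::Type;
  MPType reduce_var = static_cast<MPType>(0);
  for (const Tx& v : x) {
    reduce_var = reduce_var + static_cast<MPType>(v);
  }
  return static_cast<Ty>(reduce_var);
}

int main() {
  std::vector<float16_stub> x(1000, float16_stub(0.001f));
  // The accumulator is float even though input and output are "float16".
  std::cout << static_cast<float>(SumWithMPType<float16_stub, float16_stub>(x)) << "\n";
  return 0;
}
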
@@ -677,71 +679,72 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
     // 3. reduce in block x
     reduce_var = BlockXReduce(reduce_var, reducer);
     if (left_idx < left_num && threadIdx.x == 0) {
-      y[blockIdx.y * left_num + left_idx] = reduce_var;
+      y[blockIdx.y * left_num + left_idx] = static_cast<Ty>(reduce_var);
     }
   } else {
     if (left_idx < left_num && threadIdx.y == 0) {
-      y[blockIdx.y * left_num + left_idx] = reduce_var;
+      y[blockIdx.y * left_num + left_idx] = static_cast<Ty>(reduce_var);
     }
   }
 }

 // module function designed for global function
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
-                             TransformOp transformer, Ty init, int reduce_num,
-                             int left_num, int blocking_size, int reduce_type,
-                             bool reduce_lastdim,
+                             TransformOp transformer, MPType init,
+                             int reduce_num, int left_num, int blocking_size,
+                             int reduce_type, bool reduce_lastdim,
                              const IndexCalculator& reduce_index_calculator,
                              const IndexCalculator& left_index_calculator) {
   if (reduce_type == ReduceType::kReduceLastDim) {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
+    ReduceAny<Tx, Ty, MPType, ReduceOp, TransformOp>(
        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
        [&](int idx) { return idx; },
        [&](int idx) { return idx * reduce_num; });

    // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
   } else if (reduce_type == ReduceType::kReduceHigherDim) {
-    ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
+    ReduceHigherDim<Tx, Ty, MPType, ReduceOp, TransformOp>(
        x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);

    // reduce_rank >= 2
   } else {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
+    ReduceAny<Tx, Ty, MPType, ReduceOp, TransformOp>(
        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
        [&](int idx) { return reduce_index_calculator.Get(idx); },
        [&](int idx) { return left_index_calculator.Get(idx); });
   }
 }

-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __global__ void ReduceKernelFunction(const Tx* x, Ty* y, ReduceOp reducer,
-                                     TransformOp transformer, Ty init,
+                                     TransformOp transformer, MPType init,
                                      int reduce_num, int left_num,
                                      int blocking_size, int reduce_type,
                                      bool reduce_lastdim,
                                      IndexCalculator reduce_index_calculator,
                                      IndexCalculator left_index_calculator) {
-  ReduceModule<Tx, Ty, ReduceOp, TransformOp>(
+  ReduceModule<Tx, Ty, MPType, ReduceOp, TransformOp>(
      x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
      reduce_type, reduce_lastdim, reduce_index_calculator,
      left_index_calculator);
 }

-template <typename Tx, typename Ty, typename ReduceOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp>
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
-                               const ReduceOp& reducer, Ty init,
+                               const ReduceOp& reducer, MPType init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   using TransformOp = typename ReduceOp::Transformer;
-
   int reduce_rank = config.reduce_strides.size();
   int left_rank = config.left_strides.size();
   auto reduce_index_calculator = IndexCalculator(
      reduce_rank, config.reduce_dim, config.reduce_strides, config.x_strides);
   auto left_index_calculator = IndexCalculator(
      left_rank, config.left_dim, config.left_strides, config.x_strides);

-  ReduceKernelFunction<Tx, Ty, ReduceOp,
+  ReduceKernelFunction<Tx, Ty, MPType, ReduceOp,
                        TransformOp><<<config.grid, config.block, 0, stream>>>(
      x_data, config.output_data, reducer, TransformOp(config.reduce_num), init,
      config.reduce_num, config.left_num, config.blocking_size,
@@ -759,10 +762,11 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
     grid = dim3(config.grid.x, 1, config.grid.z);
   }

-  ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<
-                                             Ty>><<<grid, block, 0, stream>>>(
+  ReduceKernelFunction<
+      Ty, Ty, MPType, ReduceOp,
+      detail::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
      config.output_data, y_data, reducer,
-      detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
+      detail::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
      config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
      config.reduce_lastdim, reduce_index_calculator, left_index_calculator);
 }
@@ -793,11 +797,12 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   }

   config.SetOutputData(y_data, x.place(), &tmp);
-
-  using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
-  auto reducer = ReduceOp<Tx, Ty>();
-  // launch CUB::Reduce
-  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
+  bool use_cub_Reduce = (config.left_num == 1) &&
+                        (!std::is_same<Tx, paddle::platform::float16>::value);
+  if (use_cub_Reduce) {
+    // launch CUB::Reduce
+    using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
+    auto reducer = ReduceOp<Tx, Ty>();
     cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
        x_data, TransformOp(config.reduce_num));
     size_t temp_storage_bytes = 0;
@@ -815,7 +820,9 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     return;
   }

-  LaunchReduceKernel<Tx, Ty, ReduceOp<Tx, Ty>>(
+  using MPType = typename details::MPTypeTrait<Ty>::Type;
+  auto reducer = ReduceOp<Tx, MPType>();
+  LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>>(
      x_data, y_data, reducer, reducer.initial(), stream, config);
 }
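
Note: the CUB fast path is no longer selected by the removed kReduceAll tag. It is taken whenever the whole tensor collapses into a single output element (config.left_num == 1), except for float16 inputs, which always go through LaunchReduceKernel so that they are accumulated in MPType (float). The helper below merely restates that predicate; Float16Stub stands in for paddle::platform::float16 and is illustrative only.

#include <type_traits>

struct Float16Stub {};  // placeholder for paddle::platform::float16

// Restatement of the use_cub_Reduce condition above; illustrative only.
template <typename Tx>
bool UseCubReduce(int left_num) {
  // left_num == 1 means every axis is reduced into one output element;
  // float16 is excluded so it takes the custom kernel with MPType accumulation.
  return (left_num == 1) && !std::is_same<Tx, Float16Stub>::value;
}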

paddle/fluid/operators/reduce_ops/reduce_sum_op.cu

Lines changed: 13 additions & 67 deletions
@@ -11,72 +11,18 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename Tout>
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-
-  template <typename U>
-  HOSTDEVICE inline Tout operator()(const U& x) const {
-    return static_cast<Tout>(x);
-  }
-};
-
-template <typename T>
-class ReduceSumKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    bool reduce_all = context.Attr<bool>("reduce_all");
-    auto* input = context.Input<Tensor>("X");
-    auto* output = context.Output<Tensor>("Out");
-    auto out_dtype = context.Attr<int>("out_dtype");
-
-    auto dims = context.Attr<std::vector<int>>("dim");
-    bool keep_dim = context.Attr<bool>("keep_dim");
-
-    std::vector<int> reduce_dims;
-    if (reduce_all) {
-      reduce_dims.resize(input->dims().size());
-      for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
-    } else {
-      for (auto e : dims) {
-        reduce_dims.push_back(e >= 0 ? e : e + input->dims().size());
-      }
-    }
-
-    int reduce_num = 1;
-    for (int i = 0; i < reduce_dims.size(); ++i) {
-      reduce_num *= input->dims()[reduce_dims[i]];
-    }
-
-    auto stream = context.cuda_device_context().stream();
-    if (out_dtype >= 0) {
-      framework::VisitDataTypeSmall(
-          static_cast<framework::proto::VarType::Type>(out_dtype),
-          TensorReduceFunctor<T, cub::Sum, IdentityFunctor>(
-              *input, output, reduce_dims, static_cast<double>(0.0), cub::Sum(),
-              stream));
-    } else {
-      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-          *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
-          IdentityFunctor<T>(), stream);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 REGISTER_OP_CUDA_KERNEL(
-    reduce_sum, ops::ReduceSumKernel<bool>, ops::ReduceSumKernel<float>,
-    ops::ReduceSumKernel<double>,
-    ops::ReduceSumKernel<paddle::platform::float16>, ops::ReduceSumKernel<int>,
-    ops::ReduceSumKernel<int64_t>,
-    ops::ReduceSumKernel<paddle::platform::complex<float>>,
-    ops::ReduceSumKernel<paddle::platform::complex<double>>);
+    reduce_sum, ops::ReduceCudaKernel<bool, paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<float, paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<double, paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<paddle::platform::float16,
+                          paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<int, paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<paddle::platform::complex<float>,
+                          paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<paddle::platform::complex<double>,
+                          paddle::operators::CustomSum>);
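
Note: as with reduce_mean, the hand-written ReduceSumKernel (including its local IdentityFunctor and manual axis handling) is deleted, and the op is registered against the shared ops::ReduceCudaKernel with a CustomSum functor; the full dtype list, including float16 and complex, is kept. For reference, the axis normalization both deleted kernels performed before launching the reduction is shown below as a standalone helper (NormalizeReduceDims is an illustrative name, not a Paddle API).

#include <vector>

// Expand reduce_all to every axis and wrap negative axes -- the step the
// deleted ReduceMeanKernel / ReduceSumKernel both did before reducing.
std::vector<int> NormalizeReduceDims(const std::vector<int>& dims, int rank,
                                     bool reduce_all) {
  std::vector<int> reduce_dims;
  if (reduce_all) {
    reduce_dims.resize(rank);
    for (int i = 0; i < rank; ++i) reduce_dims[i] = i;
  } else {
    for (auto e : dims) {
      reduce_dims.push_back(e >= 0 ? e : e + rank);
    }
  }
  return reduce_dims;
}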
