Skip to content

Commit d9585f9

Browse files
authored
Merge pull request #4632 from luotao1/reduce
Unify Reduce functions and simplify register code
2 parents 3a68955 + 5972990 commit d9585f9

File tree

5 files changed

+35
-63
lines changed

5 files changed

+35
-63
lines changed

paddle/operators/activation_op.cc

Lines changed: 3 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -285,11 +285,9 @@ REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
285285
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
286286
REGISTER_OP_CPU_KERNEL( \
287287
act_type, \
288-
paddle::operators::ActivationKernel<paddle::platform::CPUPlace, \
289-
paddle::operators::functor<float>>); \
288+
ops::ActivationKernel<paddle::platform::CPUPlace, ops::functor<float>>); \
290289
REGISTER_OP_CPU_KERNEL(act_type##_grad, \
291-
paddle::operators::ActivationGradKernel< \
292-
paddle::platform::CPUPlace, \
293-
paddle::operators::grad_functor<float>>);
290+
ops::ActivationGradKernel<paddle::platform::CPUPlace, \
291+
ops::grad_functor<float>>);
294292

295293
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);

paddle/operators/activation_op.cu

Lines changed: 5 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -15,14 +15,14 @@
1515
#define EIGEN_USE_GPU
1616
#include "paddle/operators/activation_op.h"
1717

18+
namespace ops = paddle::operators;
19+
1820
#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
1921
REGISTER_OP_GPU_KERNEL( \
2022
act_type, \
21-
paddle::operators::ActivationKernel<paddle::platform::GPUPlace, \
22-
paddle::operators::functor<float>>); \
23+
ops::ActivationKernel<paddle::platform::GPUPlace, ops::functor<float>>); \
2324
REGISTER_OP_GPU_KERNEL(act_type##_grad, \
24-
paddle::operators::ActivationGradKernel< \
25-
paddle::platform::GPUPlace, \
26-
paddle::operators::grad_functor<float>>);
25+
ops::ActivationGradKernel<paddle::platform::GPUPlace, \
26+
ops::grad_functor<float>>);
2727

2828
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL);

paddle/operators/reduce_op.cc

Lines changed: 12 additions & 26 deletions
Original file line number · Diff line number · Diff line change
@@ -168,36 +168,22 @@ namespace ops = paddle::operators;
168168

169169
REGISTER_OP(reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker, reduce_sum_grad,
170170
ops::ReduceGradOp);
171-
REGISTER_OP_CPU_KERNEL(
172-
reduce_sum,
173-
ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::SumFunctor>);
174-
REGISTER_OP_CPU_KERNEL(reduce_sum_grad,
175-
ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
176-
ops::SumGradFunctor>);
177171

178172
REGISTER_OP(reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker,
179173
reduce_mean_grad, ops::ReduceGradOp);
180-
REGISTER_OP_CPU_KERNEL(
181-
reduce_mean,
182-
ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MeanFunctor>);
183-
REGISTER_OP_CPU_KERNEL(reduce_mean_grad,
184-
ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
185-
ops::MeanGradFunctor>);
186174

187175
REGISTER_OP(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad,
188176
ops::ReduceGradOp);
189-
REGISTER_OP_CPU_KERNEL(
190-
reduce_max,
191-
ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MaxFunctor>);
192-
REGISTER_OP_CPU_KERNEL(reduce_max_grad,
193-
ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
194-
ops::MaxOrMinGradFunctor>);
195-
196-
REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_min_grad,
177+
178+
REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker, reduce_min_grad,
197179
ops::ReduceGradOp);
198-
REGISTER_OP_CPU_KERNEL(
199-
reduce_min,
200-
ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MinFunctor>);
201-
REGISTER_OP_CPU_KERNEL(reduce_min_grad,
202-
ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
203-
ops::MaxOrMinGradFunctor>);
180+
181+
#define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \
182+
REGISTER_OP_CPU_KERNEL( \
183+
reduce_type, \
184+
ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::functor>); \
185+
REGISTER_OP_CPU_KERNEL(reduce_type##_grad, \
186+
ops::ReduceGradKernel<paddle::platform::CPUPlace, \
187+
float, ops::grad_functor>);
188+
189+
FOR_EACH_KERNEL_FUNCTOR(REGISTER_REDUCE_CPU_KERNEL);

paddle/operators/reduce_op.cu

Lines changed: 9 additions & 27 deletions
Original file line number · Diff line number · Diff line change
@@ -17,30 +17,12 @@
1717

1818
namespace ops = paddle::operators;
1919

20-
REGISTER_OP_GPU_KERNEL(
21-
reduce_sum,
22-
ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::SumFunctor>);
23-
REGISTER_OP_GPU_KERNEL(reduce_sum_grad,
24-
ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
25-
ops::SumGradFunctor>);
26-
27-
REGISTER_OP_GPU_KERNEL(
28-
reduce_mean,
29-
ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MeanFunctor>);
30-
REGISTER_OP_GPU_KERNEL(reduce_mean_grad,
31-
ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
32-
ops::MeanGradFunctor>);
33-
34-
REGISTER_OP_GPU_KERNEL(
35-
reduce_max,
36-
ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MaxFunctor>);
37-
REGISTER_OP_GPU_KERNEL(reduce_max_grad,
38-
ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
39-
ops::MaxOrMinGradFunctor>);
40-
41-
REGISTER_OP_GPU_KERNEL(
42-
reduce_min,
43-
ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MinFunctor>);
44-
REGISTER_OP_GPU_KERNEL(reduce_min_grad,
45-
ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
46-
ops::MaxOrMinGradFunctor>);
20+
#define REGISTER_REDUCE_GPU_KERNEL(reduce_type, functor, grad_functor) \
21+
REGISTER_OP_GPU_KERNEL( \
22+
reduce_type, \
23+
ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::functor>); \
24+
REGISTER_OP_GPU_KERNEL(reduce_type##_grad, \
25+
ops::ReduceGradKernel<paddle::platform::GPUPlace, \
26+
float, ops::grad_functor>);
27+
28+
FOR_EACH_KERNEL_FUNCTOR(REGISTER_REDUCE_GPU_KERNEL);

paddle/operators/reduce_op.h

Lines changed: 6 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -198,3 +198,9 @@ class ReduceGradKernel : public framework::OpKernel<T> {
198198

199199
} // namespace operators
200200
} // namespace paddle
201+
202+
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
203+
__macro(reduce_sum, SumFunctor, SumGradFunctor); \
204+
__macro(reduce_mean, MeanFunctor, MeanGradFunctor); \
205+
__macro(reduce_max, MaxFunctor, MaxOrMinGradFunctor); \
206+
__macro(reduce_min, MinFunctor, MaxOrMinGradFunctor);

0 commit comments

Comments (0)