@@ -168,36 +168,22 @@ namespace ops = paddle::operators;
168168
169169REGISTER_OP (reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker, reduce_sum_grad,
170170 ops::ReduceGradOp);
171- REGISTER_OP_CPU_KERNEL (
172- reduce_sum,
173- ops::ReduceKernel<paddle::platform::CPUPlace, float , ops::SumFunctor>);
174- REGISTER_OP_CPU_KERNEL (reduce_sum_grad,
175- ops::ReduceGradKernel<paddle::platform::CPUPlace, float ,
176- ops::SumGradFunctor>);
177171
178172REGISTER_OP (reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker,
179173 reduce_mean_grad, ops::ReduceGradOp);
180- REGISTER_OP_CPU_KERNEL (
181- reduce_mean,
182- ops::ReduceKernel<paddle::platform::CPUPlace, float , ops::MeanFunctor>);
183- REGISTER_OP_CPU_KERNEL (reduce_mean_grad,
184- ops::ReduceGradKernel<paddle::platform::CPUPlace, float ,
185- ops::MeanGradFunctor>);
186174
187175REGISTER_OP (reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad,
188176 ops::ReduceGradOp);
189- REGISTER_OP_CPU_KERNEL (
190- reduce_max,
191- ops::ReduceKernel<paddle::platform::CPUPlace, float , ops::MaxFunctor>);
192- REGISTER_OP_CPU_KERNEL (reduce_max_grad,
193- ops::ReduceGradKernel<paddle::platform::CPUPlace, float ,
194- ops::MaxOrMinGradFunctor>);
195-
196- REGISTER_OP (reduce_min, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_min_grad,
177+
178+ REGISTER_OP (reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker, reduce_min_grad,
197179 ops::ReduceGradOp);
198- REGISTER_OP_CPU_KERNEL (
199- reduce_min,
200- ops::ReduceKernel<paddle::platform::CPUPlace, float , ops::MinFunctor>);
201- REGISTER_OP_CPU_KERNEL (reduce_min_grad,
202- ops::ReduceGradKernel<paddle::platform::CPUPlace, float ,
203- ops::MaxOrMinGradFunctor>);
180+
181+ #define REGISTER_REDUCE_CPU_KERNEL (reduce_type, functor, grad_functor ) \
182+ REGISTER_OP_CPU_KERNEL ( \
183+ reduce_type, \
184+ ops::ReduceKernel<paddle::platform::CPUPlace, float , ops::functor>); \
185+ REGISTER_OP_CPU_KERNEL (reduce_type##_grad, \
186+ ops::ReduceGradKernel<paddle::platform::CPUPlace, \
187+ float , ops::grad_functor>);
188+
189+ FOR_EACH_KERNEL_FUNCTOR (REGISTER_REDUCE_CPU_KERNEL);
0 commit comments