@@ -14,14 +14,18 @@ limitations under the License. */
1414
1515#include < thrust/fill.h>
1616#include " paddle/fluid/operators/controlflow/compare_all_op.h"
17+ #include " paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
1718#include " paddle/fluid/operators/reduce_ops/cub_reduce.h"
19+
20+ namespace ops = paddle::operators;
21+ namespace plat = paddle::platform;
22+
1823namespace paddle {
1924namespace operators {
2025
2126template <typename T>
2227struct IdentityFunctor {
2328 HOSTDEVICE explicit inline IdentityFunctor () {}
24-
2529 HOSTDEVICE inline T operator ()(const T& x) const { return x; }
2630};
2731
@@ -33,6 +37,24 @@ struct BitwiseAdd {
3337 return a & b;
3438 }
3539};
40+
41+ template <typename T, typename Enable = void >
42+ struct CudaEqualReduceFunctor {
43+ using ELEM_TYPE = T;
44+ HOSTDEVICE bool operator ()(const T args[]) const {
45+ return (args[0 ] == args[1 ]);
46+ }
47+ };
48+
49+ template <typename T>
50+ struct CudaEqualReduceFunctor <
51+ T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
52+ using ELEM_TYPE = T;
53+ HOSTDEVICE bool operator ()(const T args[]) const {
54+ return fabs (static_cast <double >(args[0 ] - args[1 ])) < 1e-8 ;
55+ }
56+ };
57+
3658template <typename DeviceContext, typename Functor>
3759class CompareReduceOpKernel
3860 : public framework::OpKernel<typename Functor::ELEM_TYPE> {
@@ -44,32 +66,22 @@ class CompareReduceOpKernel
4466 auto * x = context.Input <Tensor>(" X" );
4567 auto * y = context.Input <Tensor>(" Y" );
4668 auto * z = context.Output <Tensor>(" Out" );
47- bool shape_same = true ;
48-
69+ bool * z_data = z->mutable_data <bool >(context.GetPlace ());
4970 Tensor tmp;
50- framework::DDim x_dims = x->dims ();
51- framework::DDim y_dims = y->dims ();
5271
53- if (x_dims.size () != y_dims.size ()) {
54- shape_same = false ;
55- } else {
56- for (auto i = 0 ; i < x_dims.size (); i++) {
57- if (x_dims[i] != y_dims[i]) {
58- shape_same = false ;
59- break ;
60- }
61- }
62- }
63-
64- bool * z_data = z->mutable_data <bool >(context.GetPlace ());
65- if (!shape_same) {
72+ if (x->dims () != y->dims ()) {
6673 thrust::device_ptr<bool > z_dev_ptr (z_data);
6774 thrust::fill (z_dev_ptr, z_dev_ptr + 1 , false );
6875 return ;
6976 } else {
70- tmp.mutable_data <bool >(x_dims, context.GetPlace ());
71- ElementwiseComputeEx<Functor, DeviceContext, T, bool >(context, x, y, 0 ,
72- Functor (), &tmp);
77+ tmp.mutable_data <bool >(x->dims (), context.GetPlace ());
78+ const auto & cuda_ctx =
79+ context.template device_context <platform::CUDADeviceContext>();
80+ std::vector<const framework::Tensor*> ins = {x, y};
81+ std::vector<framework::Tensor*> outs = {&tmp};
82+ LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary , T, bool >(
83+ cuda_ctx, ins, &outs, Functor ());
84+
7385 // Reduce by 'bitwise and' operator
7486 std::vector<int > reduce_dims;
7587 reduce_dims.resize (tmp.dims ().size ());
@@ -85,18 +97,17 @@ class CompareReduceOpKernel
8597} // namespace operators
8698} // namespace paddle
8799
88- #define REGISTER_COMPARE_REDUCE_CUDA_KERNEL (op_type, functor ) \
89- REGISTER_OP_CUDA_KERNEL ( \
90- op_type, paddle::operators::CompareReduceOpKernel< \
91- paddle::platform::CUDADeviceContext, functor<bool >>, \
92- paddle::operators::CompareReduceOpKernel< \
93- paddle::platform::CUDADeviceContext, functor<int >>, \
94- paddle::operators::CompareReduceOpKernel< \
95- paddle::platform::CUDADeviceContext, functor<int64_t >>, \
96- paddle::operators::CompareReduceOpKernel< \
97- paddle::platform::CUDADeviceContext, functor<float >>, \
98- paddle::operators::CompareReduceOpKernel< \
99- paddle::platform::CUDADeviceContext, functor<double >>);
100-
101- REGISTER_COMPARE_REDUCE_CUDA_KERNEL (equal_all,
102- paddle::operators::EqualReduceFunctor);
100+ #define REGISTER_COMPARE_REDUCE_CUDA_KERNEL (op_type, functor ) \
101+ REGISTER_OP_CUDA_KERNEL ( \
102+ op_type, \
103+ ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<bool >>, \
104+ ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<int >>, \
105+ ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
106+ ops::functor<int64_t >>, \
107+ ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
108+ ops::functor<float >>, \
109+ ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
110+ ops::functor<double >>);
111+
112+ REGISTER_COMPARE_REDUCE_CUDA_KERNEL (equal_all, CudaEqualReduceFunctor)
113+ #undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL
0 commit comments