Skip to content

Commit ea1a0d4

Browse files
authored
Replace usage of elementwise cuda forward kernel in Compare_all_op (#33754)
1 parent 4d16724 commit ea1a0d4

File tree

3 files changed

+50
-56
lines changed

3 files changed

+50
-56
lines changed

paddle/fluid/operators/controlflow/compare_all_op.cc

Lines changed: 3 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -30,29 +30,13 @@ class CompareReduceOpKernel
3030
auto* x = context.Input<Tensor>("X");
3131
auto* y = context.Input<Tensor>("Y");
3232
auto* z = context.Output<Tensor>("Out");
33-
bool shape_same = true;
34-
3533
Tensor tmp;
36-
framework::DDim x_dims = x->dims();
37-
framework::DDim y_dims = y->dims();
38-
39-
// judge the two inputs shape is same, if not same, just return false
40-
if (x_dims.size() != y_dims.size()) {
41-
shape_same = false;
42-
} else {
43-
for (auto i = 0; i < x_dims.size(); i++) {
44-
if (x_dims[i] != y_dims[i]) {
45-
shape_same = false;
46-
break;
47-
}
48-
}
49-
}
50-
5134
bool* z_data = z->mutable_data<bool>(context.GetPlace());
52-
if (!shape_same) {
35+
36+
if (x->dims() != y->dims()) {
5337
z_data[0] = false;
5438
} else {
55-
tmp.mutable_data<bool>(x_dims, context.GetPlace());
39+
tmp.mutable_data<bool>(x->dims(), context.GetPlace());
5640
if (x->numel() == 1 && y->numel() == 1) {
5741
bool* z_data = tmp.mutable_data<bool>(context.GetPlace());
5842
z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);

paddle/fluid/operators/controlflow/compare_all_op.cu

Lines changed: 47 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -14,14 +14,18 @@ limitations under the License. */
1414

1515
#include <thrust/fill.h>
1616
#include "paddle/fluid/operators/controlflow/compare_all_op.h"
17+
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
1718
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
19+
20+
namespace ops = paddle::operators;
21+
namespace plat = paddle::platform;
22+
1823
namespace paddle {
1924
namespace operators {
2025

2126
template <typename T>
2227
struct IdentityFunctor {
2328
HOSTDEVICE explicit inline IdentityFunctor() {}
24-
2529
HOSTDEVICE inline T operator()(const T& x) const { return x; }
2630
};
2731

@@ -33,6 +37,24 @@ struct BitwiseAdd {
3337
return a & b;
3438
}
3539
};
40+
41+
template <typename T, typename Enable = void>
42+
struct CudaEqualReduceFunctor {
43+
using ELEM_TYPE = T;
44+
HOSTDEVICE bool operator()(const T args[]) const {
45+
return (args[0] == args[1]);
46+
}
47+
};
48+
49+
template <typename T>
50+
struct CudaEqualReduceFunctor<
51+
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
52+
using ELEM_TYPE = T;
53+
HOSTDEVICE bool operator()(const T args[]) const {
54+
return fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
55+
}
56+
};
57+
3658
template <typename DeviceContext, typename Functor>
3759
class CompareReduceOpKernel
3860
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
@@ -44,32 +66,22 @@ class CompareReduceOpKernel
4466
auto* x = context.Input<Tensor>("X");
4567
auto* y = context.Input<Tensor>("Y");
4668
auto* z = context.Output<Tensor>("Out");
47-
bool shape_same = true;
48-
69+
bool* z_data = z->mutable_data<bool>(context.GetPlace());
4970
Tensor tmp;
50-
framework::DDim x_dims = x->dims();
51-
framework::DDim y_dims = y->dims();
5271

53-
if (x_dims.size() != y_dims.size()) {
54-
shape_same = false;
55-
} else {
56-
for (auto i = 0; i < x_dims.size(); i++) {
57-
if (x_dims[i] != y_dims[i]) {
58-
shape_same = false;
59-
break;
60-
}
61-
}
62-
}
63-
64-
bool* z_data = z->mutable_data<bool>(context.GetPlace());
65-
if (!shape_same) {
72+
if (x->dims() != y->dims()) {
6673
thrust::device_ptr<bool> z_dev_ptr(z_data);
6774
thrust::fill(z_dev_ptr, z_dev_ptr + 1, false);
6875
return;
6976
} else {
70-
tmp.mutable_data<bool>(x_dims, context.GetPlace());
71-
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, 0,
72-
Functor(), &tmp);
77+
tmp.mutable_data<bool>(x->dims(), context.GetPlace());
78+
const auto& cuda_ctx =
79+
context.template device_context<platform::CUDADeviceContext>();
80+
std::vector<const framework::Tensor*> ins = {x, y};
81+
std::vector<framework::Tensor*> outs = {&tmp};
82+
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, bool>(
83+
cuda_ctx, ins, &outs, Functor());
84+
7385
// Reduce by 'bitwise and' operator
7486
std::vector<int> reduce_dims;
7587
reduce_dims.resize(tmp.dims().size());
@@ -85,18 +97,17 @@ class CompareReduceOpKernel
8597
} // namespace operators
8698
} // namespace paddle
8799

88-
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
89-
REGISTER_OP_CUDA_KERNEL( \
90-
op_type, paddle::operators::CompareReduceOpKernel< \
91-
paddle::platform::CUDADeviceContext, functor<bool>>, \
92-
paddle::operators::CompareReduceOpKernel< \
93-
paddle::platform::CUDADeviceContext, functor<int>>, \
94-
paddle::operators::CompareReduceOpKernel< \
95-
paddle::platform::CUDADeviceContext, functor<int64_t>>, \
96-
paddle::operators::CompareReduceOpKernel< \
97-
paddle::platform::CUDADeviceContext, functor<float>>, \
98-
paddle::operators::CompareReduceOpKernel< \
99-
paddle::platform::CUDADeviceContext, functor<double>>);
100-
101-
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all,
102-
paddle::operators::EqualReduceFunctor);
100+
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
101+
REGISTER_OP_CUDA_KERNEL( \
102+
op_type, \
103+
ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<bool>>, \
104+
ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<int>>, \
105+
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
106+
ops::functor<int64_t>>, \
107+
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
108+
ops::functor<float>>, \
109+
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
110+
ops::functor<double>>);
111+
112+
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, CudaEqualReduceFunctor)
113+
#undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL

paddle/fluid/operators/controlflow/compare_op.cu

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,6 @@ struct CudaNotEqualFunctor<
5959
template <typename Functor, typename InverseFunctor>
6060
class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
6161
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
62-
public:
6362
public:
6463
using InT = typename Functor::ELEMENT_TYPE;
6564
using OutT = bool;

0 commit comments

Comments
 (0)