@@ -13,18 +13,85 @@ See the License for the specific language governing permissions and
1313limitations under the License. */
1414
1515#include " paddle/fluid/operators/controlflow/compare_op.h"
16+ #include " paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
1617
17- REGISTER_COMPARE_KERNEL (less_than, CUDA, paddle::operators::LessThanFunctor,
18- paddle::operators::GreaterThanFunctor);
19- REGISTER_COMPARE_KERNEL (less_equal, CUDA, paddle::operators::LessEqualFunctor,
20- paddle::operators::GreaterEqualFunctor);
21- REGISTER_COMPARE_KERNEL (greater_than, CUDA,
22- paddle::operators::GreaterThanFunctor,
23- paddle::operators::LessThanFunctor);
24- REGISTER_COMPARE_KERNEL (greater_equal, CUDA,
25- paddle::operators::GreaterEqualFunctor,
26- paddle::operators::LessEqualFunctor);
27- REGISTER_COMPARE_KERNEL (equal, CUDA, paddle::operators::EqualFunctor,
28- paddle::operators::EqualFunctor);
29- REGISTER_COMPARE_KERNEL (not_equal, CUDA, paddle::operators::NotEqualFunctor,
30- paddle::operators::NotEqualFunctor);
18+ namespace ops = paddle::operators;
19+ namespace plat = paddle::platform;
20+
21+ namespace paddle {
22+ namespace operators {
23+
24+ #define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (Func, op ) \
25+ template <typename T, typename Enable = void > \
26+ struct Func ##Functor { \
27+ using ELEMENT_TYPE = T; \
28+ inline HOSTDEVICE bool operator ()(const T* args) const { \
29+ return args[0 ] op args[1 ]; \
30+ } \
31+ };
32+
33+ DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaLessThan, <)
34+ DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaLessEqual, <=)
35+ DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaGreaterThan, >)
36+ DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaGreaterEqual, >=)
37+ DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaEqual, ==)
38+ DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaNotEqual, !=)
39+ #undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
40+
41+ template <typename T>
42+ struct CudaEqualFunctor <
43+ T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
44+ using ELEMENT_TYPE = T;
45+ HOSTDEVICE bool operator ()(const T* args) const {
46+ return fabs (static_cast <double >(args[0 ] - args[1 ])) < 1e-8 ;
47+ }
48+ };
49+
50+ template <typename T>
51+ struct CudaNotEqualFunctor <
52+ T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
53+ using ELEMENT_TYPE = T;
54+ HOSTDEVICE bool operator ()(const T* args) const {
55+ return fabs (static_cast <double >(args[0 ] - args[1 ])) > 1e-8 ;
56+ }
57+ };
58+
59+ template <typename Functor, typename InverseFunctor>
60+ class CompareOpKernel <platform::CUDADeviceContext, Functor, InverseFunctor>
61+ : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
62+ public:
63+ public:
64+ using InT = typename Functor::ELEMENT_TYPE;
65+ using OutT = bool ;
66+ void Compute (const framework::ExecutionContext& ctx) const override {
67+ auto functor = Functor ();
68+ std::vector<const framework::Tensor*> ins;
69+ std::vector<framework::Tensor*> outs;
70+
71+ PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
72+ LaunchElementwiseCudaKernel<ElementwiseType::kBinary , InT, OutT>(
73+ ctx, ins, &outs, functor);
74+ }
75+ };
76+
77+ } // namespace operators
78+ } // namespace paddle
79+
80+ #define REGISTER_CUDA_COMPARE_KERNEL (op_type, func ) \
81+ REGISTER_OP_CUDA_KERNEL ( \
82+ op_type, ops::CompareOpKernel<plat::CUDADeviceContext, \
83+ ops::func##Functor<int >, void >, \
84+ ops::CompareOpKernel<plat::CUDADeviceContext, \
85+ ops::func##Functor<int64_t >, void >, \
86+ ops::CompareOpKernel<plat::CUDADeviceContext, ops::func##Functor<float >, \
87+ void >, \
88+ ops::CompareOpKernel<plat::CUDADeviceContext, \
89+ ops::func##Functor<double >, void >);
90+
91+ REGISTER_CUDA_COMPARE_KERNEL (equal, CudaEqual)
92+ REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual)
93+ REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThan)
94+ REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqual)
95+ REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThan)
96+ REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqual)
97+ #undef REGISTER_CUDA_COMPARE_KERNEL
0 commit comments