@@ -21,21 +21,21 @@ namespace plat = paddle::platform;
2121namespace paddle {
2222namespace operators {
2323
24- #define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (Func , op ) \
24+ #define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (func , op ) \
2525 template <typename T, typename Enable = void > \
26- struct Func ##Functor { \
26+ struct func { \
2727 using ELEMENT_TYPE = T; \
2828 inline HOSTDEVICE bool operator ()(const T* args) const { \
2929 return args[0 ] op args[1 ]; \
3030 } \
3131 };
3232
33- DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaLessThan , <)
34- DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaLessEqual , <=)
35- DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaGreaterThan , >)
36- DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaGreaterEqual , >=)
37- DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaEqual , ==)
38- DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaNotEqual , !=)
33+ DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaLessThanFunctor , <)
34+ DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaLessEqualFunctor , <=)
35+ DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaGreaterThanFunctor , >)
36+ DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaGreaterEqualFunctor , >=)
37+ DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaEqualFunctor , ==)
38+ DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT (CudaNotEqualFunctor , !=)
3939#undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
4040
4141template <typename T>
@@ -67,10 +67,12 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
6767 auto functor = Functor ();
6868 std::vector<const framework::Tensor*> ins;
6969 std::vector<framework::Tensor*> outs;
70+ const auto & cuda_ctx =
71+ ctx.template device_context <platform::CUDADeviceContext>();
7072
71- PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
73+ int axis = PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
7274 LaunchElementwiseCudaKernel<ElementwiseType::kBinary , InT, OutT>(
73- ctx , ins, &outs, functor);
75+ cuda_ctx , ins, &outs, axis , functor);
7476 }
7577};
7678
@@ -79,19 +81,16 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
7981
8082#define REGISTER_CUDA_COMPARE_KERNEL (op_type, func ) \
8183 REGISTER_OP_CUDA_KERNEL ( \
82- op_type, ops::CompareOpKernel<plat::CUDADeviceContext, \
83- ops::func##Functor<int >, void >, \
84- ops::CompareOpKernel<plat::CUDADeviceContext, \
85- ops::func##Functor<int64_t >, void >, \
86- ops::CompareOpKernel<plat::CUDADeviceContext, ops::func##Functor<float >, \
87- void >, \
88- ops::CompareOpKernel<plat::CUDADeviceContext, \
89- ops::func##Functor<double >, void >);
84+ op_type, \
85+ ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int >, void >, \
86+ ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int64_t >, void >, \
87+ ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float >, void >, \
88+ ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<double >, void >);
9089
91- REGISTER_CUDA_COMPARE_KERNEL (equal, CudaEqual )
92- REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual )
93- REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThan )
94- REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqual )
95- REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThan )
96- REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqual )
90+ REGISTER_CUDA_COMPARE_KERNEL (equal, CudaEqualFunctor )
91+ REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqualFunctor )
92+ REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThanFunctor )
93+ REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqualFunctor )
94+ REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThanFunctor )
95+ REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqualFunctor )
9796#undef REGISTER_CUDA_COMPARE_KERNEL
0 commit comments