@@ -60,14 +60,13 @@ template <typename DeviceContext, typename Functor>
// NOTE(review): this span looks like a scraped diff view rather than plain
// source: lines tagged `-` are the removed version, lines tagged `+` are the
// added version, and the leading digits appear to be the page's old/new line
// numbers fused onto the code.
//
// CompareOpCudaKernel: operator kernel parameterized on a device context and a
// comparison Functor; the kernel's element type is Functor::ELEMENT_TYPE.
6060class CompareOpCudaKernel
6161 : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
6262 public:
// Removed version: separate aliases for the input element type and a bool
// output type.
63- using InT = typename Functor::ELEMENT_TYPE;
64- using OutT = bool ;
// NOTE(review): the added version introduces a second `public:` directly after
// the one above — redundant, presumably unintentional.
63+ public:
// Added version: a single alias T for the functor's element type.
64+ using T = typename Functor::ELEMENT_TYPE ;
// Compute: fetches the two inputs and the output from the execution context,
// allocates the output, resolves `axis`, and launches the elementwise kernel.
6565 void Compute (const framework::ExecutionContext& ctx) const override {
// Removed version: tensors fetched as framework::Tensor and the output
// allocated with OutT (bool).
66- auto * x = ctx.Input <framework::Tensor>(" X" );
67- auto * y = ctx.Input <framework::Tensor>(" Y" );
68- auto * z = ctx.Output <framework::Tensor>(" Out" );
69- z->mutable_data <OutT>(ctx.GetPlace ());
70-
// Added version: tensors fetched as framework::LoDTensor.
66+ auto * x = ctx.Input <framework::LoDTensor>(" X" );
67+ auto * y = ctx.Input <framework::LoDTensor>(" Y" );
68+ auto * z = ctx.Output <framework::LoDTensor>(" Out" );
// NOTE(review): the output is allocated here with element type T, but the
// launch below passes `bool` as the output type and the removed version
// allocated OutT (= bool). This looks like a type mismatch for the output
// buffer — confirm whether this should be mutable_data<bool>.
69+ z->mutable_data <T>(ctx.GetPlace ());
7170 int axis = ctx.Attr <int >(" axis" );
// An axis of -1 is replaced by the absolute difference between the ranks of
// x and y.
7271 axis = axis == -1 ? std::abs (x->dims ().size () - y->dims ().size ()) : axis;
7372 auto functor = Functor ();
// Hunk boundary: the lines between the two hunks are not shown here.
// Presumably `ins` and `outs` (used in the launch below) are defined in the
// elided part — verify against the full file.
@@ -77,7 +76,7 @@ class CompareOpCudaKernel
// Obtain the CUDA device context from the execution context.
7776 const auto & cuda_ctx =
7877 ctx.template device_context <platform::CUDADeviceContext>();
7978
// Launch the binary elementwise kernel. The removed version spelled the
// template arguments as <kBinary, InT, OutT>; the added version uses
// <kBinary, T, bool>.
80- LaunchElementwiseCudaKernel<ElementwiseType::kBinary , InT, OutT >(
79+ LaunchElementwiseCudaKernel<ElementwiseType::kBinary , T, bool >(
8180 cuda_ctx, ins, &outs, axis, functor);
8281 }
8382};
0 commit comments