Skip to content

Commit 259ecb2

Browse files
committed
delete print codes
1 parent 9d5db8d commit 259ecb2

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

paddle/fluid/operators/controlflow/compare_op.cu

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,13 @@ template <typename DeviceContext, typename Functor>
6060
class CompareOpCudaKernel
6161
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
6262
public:
63-
using InT = typename Functor::ELEMENT_TYPE;
64-
using OutT = bool;
63+
public:
64+
using T = typename Functor::ELEMENT_TYPE;
6565
void Compute(const framework::ExecutionContext& ctx) const override {
66-
auto* x = ctx.Input<framework::Tensor>("X");
67-
auto* y = ctx.Input<framework::Tensor>("Y");
68-
auto* z = ctx.Output<framework::Tensor>("Out");
69-
z->mutable_data<OutT>(ctx.GetPlace());
70-
66+
auto* x = ctx.Input<framework::LoDTensor>("X");
67+
auto* y = ctx.Input<framework::LoDTensor>("Y");
68+
auto* z = ctx.Output<framework::LoDTensor>("Out");
69+
z->mutable_data<T>(ctx.GetPlace());
7170
int axis = ctx.Attr<int>("axis");
7271
axis = axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis;
7372
auto functor = Functor();
@@ -77,7 +76,7 @@ class CompareOpCudaKernel
7776
const auto& cuda_ctx =
7877
ctx.template device_context<platform::CUDADeviceContext>();
7978

80-
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
79+
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, bool>(
8180
cuda_ctx, ins, &outs, axis, functor);
8281
}
8382
};

0 commit comments

Comments
 (0)