Skip to content

Commit f6b518c

Browse files
committed
Fix elementwise_mul_op.cc
1 parent cb28428 commit f6b518c

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

paddle/operators/elementwise_mul_op.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class ElementWiseMulOp : public framework::OperatorWithKernel {
3131
auto y_dim = ctx.Input<Tensor>("Y")->dims();
3232
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
3333
"Rank of first input must >= rank of second input.")
34-
ctx.Output<framework::Tensor>("Out")->Resize(x_dim);
34+
ctx.Output<framework::LoDTensor>("Out")->Resize(x_dim);
3535
}
3636
};
3737

@@ -80,8 +80,10 @@ class ElementWiseMulOpGrad : public framework::OperatorWithKernel {
8080
auto x_dims = ctx.Input<Tensor>("X")->dims();
8181
auto y_dims = ctx.Input<Tensor>("Y")->dims();
8282
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
83-
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
84-
auto *y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
83+
auto *x_grad =
84+
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
85+
auto *y_grad =
86+
ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
8587

8688
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
8789
"Rank of first input must >= rank of second input.")

paddle/pybind/pybind.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,6 @@ All parameter, weight, gradient are variables in Paddle.
176176
.def("set_int",
177177
[](Variable &var, int val) -> void { *var.GetMutable<int>() = val; })
178178
.def("get_int", [](const Variable &var) -> int { return var.Get<int>(); })
179-
// .def("get_tensor",
180-
// [](Variable &self) -> Tensor * { return
181-
// self.GetMutable<Tensor>(); },
182-
// py::return_value_policy::reference)
183179
.def("get_tensor",
184180
[](Variable &self) -> LoDTensor * {
185181
return self.GetMutable<LoDTensor>();

0 commit comments

Comments
 (0)