Skip to content

Commit 98d2531

Browse files
authored
change TensorCopy to ShareDataWith in matmul_grad op (#33755)
1 parent 3946afc commit 98d2531

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

paddle/fluid/operators/matmul_v2_op_npu.cc

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,17 +141,13 @@ class MatMulV2GradNPUKernel : public framework::OpKernel<T> {
141141
if ((x->dims().size() == 3) && (dout->dims().size() == 3) &&
142142
(dy->dims().size() == 2)) {
143143
framework::Tensor dout_;
144-
TensorCopy(*dout, ctx.GetPlace(), &dout_);
145-
ctx.template device_context<paddle::platform::NPUDeviceContext>()
146-
.Wait();
144+
dout_.ShareDataWith(*dout);
147145
std::vector<int> vec_dim = framework::vectorize<int>(dout_.dims());
148146
std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]};
149147
dout_.Resize(framework::make_ddim(vec_dim_v));
150148

151149
framework::Tensor x_;
152-
TensorCopy(*x, ctx.GetPlace(), &x_);
153-
ctx.template device_context<paddle::platform::NPUDeviceContext>()
154-
.Wait();
150+
x_.ShareDataWith(*x);
155151
std::vector<int> vec_dim_x = framework::vectorize<int>(x_.dims());
156152
std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1],
157153
vec_dim_x[2]};

0 commit comments

Comments
 (0)