File tree Expand file tree Collapse file tree 1 file changed +2
-6
lines changed Expand file tree Collapse file tree 1 file changed +2
-6
lines changed Original file line number Diff line number Diff line change @@ -141,17 +141,13 @@ class MatMulV2GradNPUKernel : public framework::OpKernel<T> {
141141 if ((x->dims ().size () == 3 ) && (dout->dims ().size () == 3 ) &&
142142 (dy->dims ().size () == 2 )) {
143143 framework::Tensor dout_;
144- TensorCopy (*dout, ctx.GetPlace (), &dout_);
145- ctx.template device_context <paddle::platform::NPUDeviceContext>()
146- .Wait ();
144+ dout_.ShareDataWith (*dout);
147145 std::vector<int > vec_dim = framework::vectorize<int >(dout_.dims ());
148146 std::vector<int > vec_dim_v{vec_dim[0 ] * vec_dim[1 ], vec_dim[2 ]};
149147 dout_.Resize (framework::make_ddim (vec_dim_v));
150148
151149 framework::Tensor x_;
152- TensorCopy (*x, ctx.GetPlace (), &x_);
153- ctx.template device_context <paddle::platform::NPUDeviceContext>()
154- .Wait ();
150+ x_.ShareDataWith (*x);
155151 std::vector<int > vec_dim_x = framework::vectorize<int >(x_.dims ());
156152 std::vector<int > vec_dim_x_v{vec_dim_x[0 ] * vec_dim_x[1 ],
157153 vec_dim_x[2 ]};
You can’t perform that action at this time.
0 commit comments