@@ -140,19 +140,15 @@ class MulGradNPUKernel : public framework::OpKernel<T> {
140140 // matmul
141141 if (dx) {
142142 // matmul [2, 5] * [12, 5] => [2, 12]
143- Tensor tmp_matmul (y->type ());
144- tmp_matmul.Resize (
145- framework::make_ddim ({dout->dims ()[0 ], y->dims ()[0 ]}));
146- tmp_matmul.mutable_data <T>(ctx.GetPlace ());
143+ dx->mutable_data <T>(ctx.GetPlace ());
144+ auto dx_dims = dx->dims ();
145+ dx->Resize (framework::make_ddim ({dout->dims ()[0 ], y->dims ()[0 ]}));
147146 auto runner_matmul =
148- NpuOpRunner (" MatMul" , {*dout, *y}, {tmp_matmul },
147+ NpuOpRunner (" MatMul" , {*dout, *y}, {*dx },
149148 {{" transpose_x1" , false }, {" transpose_x2" , true }});
150149 runner_matmul.Run (stream);
151150 // reshape [2, 12] => [2, 3, 4]
152- dx->mutable_data (ctx.GetPlace (), x->type ());
153- framework::TensorCopy (
154- tmp_matmul, ctx.GetPlace (),
155- ctx.template device_context <platform::DeviceContext>(), dx);
151+ dx->Resize (dx_dims);
156152 }
157153
158154 if (dy) {
@@ -193,18 +189,15 @@ class MulGradNPUKernel : public framework::OpKernel<T> {
193189
194190 if (dx) {
195191 // tmp_dout * y [6,5] * [4,5] => [6, 4]
196- Tensor tmp_matmul (y-> type ());
197- tmp_matmul. Resize ( framework::make_ddim ({dout_first_dim, y ->dims ()[ 0 ]}) );
198- tmp_matmul. mutable_data <T>(ctx. GetPlace ( ));
192+ dx-> mutable_data <T>(ctx. GetPlace ());
193+ auto dx_dims = dx ->dims ();
194+ dx-> Resize ( framework::make_ddim ({dout_first_dim, y-> dims ()[ 0 ]} ));
199195 auto runner_matmul =
200- NpuOpRunner (" MatMul" , {tmp_dout, *y}, {tmp_matmul },
196+ NpuOpRunner (" MatMul" , {tmp_dout, *y}, {*dx },
201197 {{" transpose_x1" , false }, {" transpose_x2" , true }});
202198 runner_matmul.Run (stream);
203- // reshape [6,4] => [2, 3, 4]
204- dx->mutable_data (ctx.GetPlace (), x->type ());
205- framework::TensorCopy (
206- tmp_matmul, ctx.GetPlace (),
207- ctx.template device_context <platform::DeviceContext>(), dx);
199+ // reshape [6, 4] => [2, 3, 4]
200+ dx->Resize (dx_dims);
208201 }
209202 if (dy) {
210203 // flatten x.shape [2,3,4] => [6, 4]
0 commit comments