Skip to content

Commit 1de6daf

Browse files
authored
[NPU] fix shape of dx in mul_grad (#31675)
* fix shape of dx
* refine code
1 parent 3dd992e commit 1de6daf

File tree

1 file changed

+11
-18
lines changed

1 file changed

+11
-18
lines changed

paddle/fluid/operators/mul_op_npu.cc

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -140,19 +140,15 @@ class MulGradNPUKernel : public framework::OpKernel<T> {
140140
// matmul
141141
if (dx) {
142142
// matmul [2, 5] * [12, 5] => [2, 12]
143-
Tensor tmp_matmul(y->type());
144-
tmp_matmul.Resize(
145-
framework::make_ddim({dout->dims()[0], y->dims()[0]}));
146-
tmp_matmul.mutable_data<T>(ctx.GetPlace());
143+
dx->mutable_data<T>(ctx.GetPlace());
144+
auto dx_dims = dx->dims();
145+
dx->Resize(framework::make_ddim({dout->dims()[0], y->dims()[0]}));
147146
auto runner_matmul =
148-
NpuOpRunner("MatMul", {*dout, *y}, {tmp_matmul},
147+
NpuOpRunner("MatMul", {*dout, *y}, {*dx},
149148
{{"transpose_x1", false}, {"transpose_x2", true}});
150149
runner_matmul.Run(stream);
151150
// reshape [2, 12] => [2, 3, 4]
152-
dx->mutable_data(ctx.GetPlace(), x->type());
153-
framework::TensorCopy(
154-
tmp_matmul, ctx.GetPlace(),
155-
ctx.template device_context<platform::DeviceContext>(), dx);
151+
dx->Resize(dx_dims);
156152
}
157153

158154
if (dy) {
@@ -193,18 +189,15 @@ class MulGradNPUKernel : public framework::OpKernel<T> {
193189

194190
if (dx) {
195191
// tmp_dout * y [6,5] * [4,5] => [6, 4]
196-
Tensor tmp_matmul(y->type());
197-
tmp_matmul.Resize(framework::make_ddim({dout_first_dim, y->dims()[0]}));
198-
tmp_matmul.mutable_data<T>(ctx.GetPlace());
192+
dx->mutable_data<T>(ctx.GetPlace());
193+
auto dx_dims = dx->dims();
194+
dx->Resize(framework::make_ddim({dout_first_dim, y->dims()[0]}));
199195
auto runner_matmul =
200-
NpuOpRunner("MatMul", {tmp_dout, *y}, {tmp_matmul},
196+
NpuOpRunner("MatMul", {tmp_dout, *y}, {*dx},
201197
{{"transpose_x1", false}, {"transpose_x2", true}});
202198
runner_matmul.Run(stream);
203-
// reshape [6,4] => [2, 3, 4]
204-
dx->mutable_data(ctx.GetPlace(), x->type());
205-
framework::TensorCopy(
206-
tmp_matmul, ctx.GetPlace(),
207-
ctx.template device_context<platform::DeviceContext>(), dx);
199+
// reshape [6, 4] => [2, 3, 4]
200+
dx->Resize(dx_dims);
208201
}
209202
if (dy) {
210203
// flatten x.shape [2,3,4] => [6, 4]

0 commit comments

Comments
 (0)