
Commit ef6088a

modify expand_v2_op_npu.cc
1 parent 40929fe commit ef6088a

1 file changed: +13 -6 lines changed


paddle/fluid/operators/expand_v2_op_npu.cc

Lines changed: 13 additions & 6 deletions
@@ -1,14 +1,16 @@
-
 /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
 http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
-limitations under the License. */
+limitations under the Licnse. */
 
 #include "paddle/fluid/operators/expand_v2_op.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
@@ -131,19 +133,22 @@ class ExpandV2NPUGradKernel : public framework::OpKernel<T> {
     for (auto i = 0; i < reduce_ndim; ++i) {
       axes.push_back(i);
     }
-    Tensor* tmp_dout = const_cast<Tensor*>(dout);
+    // Tensor* tmp_dout = const_cast<Tensor*>(dout);
+    Tensor tmp_dout(dout->type());
     Tensor reduced_dout(dx->type());
+    tmp_dout.ShareDataWith(*dout);
     if (axes.size() != 0) {
       std::vector<int64_t> reduced_dout_dims;
       for (auto i = reduce_ndim; i < dout->dims().size(); ++i) {
         reduced_dout_dims.push_back(dout->dims()[i]);
       }
+      tmp_dout.Resize(framework::make_ddim(reduced_dout_dims));
       reduced_dout.Resize(framework::make_ddim(reduced_dout_dims));
       reduced_dout.mutable_data<T>(ctx.GetPlace());
       const auto& runner = NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout},
                                        {{"axes", axes}, {"keep_dims", false}});
       runner.Run(stream);
-      tmp_dout = &reduced_dout;
+      tmp_dout = reduced_dout;
     }
 
     // case 2: reduce axis of dout in which dim is 1
@@ -158,11 +163,11 @@ class ExpandV2NPUGradKernel : public framework::OpKernel<T> {
       }
     }
     if (axes.size() != 0) {
-      const auto& runner = NpuOpRunner("ReduceSumD", {*tmp_dout}, {*dx},
+      const auto& runner = NpuOpRunner("ReduceSumD", {tmp_dout}, {*dx},
                                        {{"axes", axes}, {"keep_dims", true}});
       runner.Run(stream);
     } else {
-      framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dx);
+      framework::TensorCopySync(tmp_dout, ctx.GetPlace(), dx);
     }
   }
 };
@@ -181,4 +186,6 @@ REGISTER_OP_NPU_KERNEL(
 REGISTER_OP_NPU_KERNEL(
     expand_v2_grad,
     ops::ExpandV2NPUGradKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::ExpandV2NPUGradKernel<paddle::platform::NPUDeviceContext,
+                               paddle::platform::float16>,
     ops::ExpandV2NPUGradKernel<paddle::platform::NPUDeviceContext, int>);
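
The functional change in this diff is the removal of the const_cast on the input gradient dout: instead of obtaining a writable Tensor* by casting away constness and then repointing it at reduced_dout, the grad kernel now works on a local Tensor value that aliases dout's storage. The following is a sketch only, not a drop-in excerpt; it reuses the names from the diff (dout, dx, axes, reduced_dout_dims, stream, ctx) and assumes the 2021 fluid framework::Tensor semantics, where ShareDataWith aliases another tensor's buffer, Resize changes only the calling tensor's own dims, and assignment copies the tensor descriptor while sharing storage.

    // Sketch of the pattern this commit applies (assumptions noted above).
    Tensor tmp_dout(dout->type());   // local value, dout itself stays const
    tmp_dout.ShareDataWith(*dout);   // alias dout's buffer instead of const_cast

    if (axes.size() != 0) {
      // Leading expanded dimensions of dout are summed away into reduced_dout.
      tmp_dout.Resize(framework::make_ddim(reduced_dout_dims));  // resizes only this view's dims
      reduced_dout.Resize(framework::make_ddim(reduced_dout_dims));
      reduced_dout.mutable_data<T>(ctx.GetPlace());
      const auto& runner = NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout},
                                       {{"axes", axes}, {"keep_dims", false}});
      runner.Run(stream);
      tmp_dout = reduced_dout;       // rebind the local view to the reduced result
    }

    // Later, tmp_dout is consumed by value in both branches:
    //   NpuOpRunner("ReduceSumD", {tmp_dout}, {*dx}, {{"axes", axes}, {"keep_dims", true}})
    //   framework::TensorCopySync(tmp_dout, ctx.GetPlace(), dx)

The two new REGISTER_OP_NPU_KERNEL lines add a float16 instantiation of ExpandV2NPUGradKernel; the kernel body itself needs no fp16-specific changes, presumably because ReduceSumD and TensorCopySync already handle float16 on the Ascend backend.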
