1-
21/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
33Licensed under the Apache License, Version 2.0 (the "License");
44you may not use this file except in compliance with the License.
55You may obtain a copy of the License at
6+
67 http://www.apache.org/licenses/LICENSE-2.0
8+
79Unless required by applicable law or agreed to in writing, software
810distributed under the License is distributed on an "AS IS" BASIS,
911WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1012See the License for the specific language governing permissions and
11- limitations under the License . */
13+ limitations under the License. */
1214
1315#include " paddle/fluid/operators/expand_v2_op.h"
1416#include " paddle/fluid/operators/npu_op_runner.h"
@@ -131,19 +133,22 @@ class ExpandV2NPUGradKernel : public framework::OpKernel<T> {
131133 for (auto i = 0 ; i < reduce_ndim; ++i) {
132134 axes.push_back (i);
133135 }
134- Tensor* tmp_dout = const_cast <Tensor*>(dout);
136+ // Tensor* tmp_dout = const_cast<Tensor*>(dout);
137+ Tensor tmp_dout (dout->type ());
135138 Tensor reduced_dout (dx->type ());
139+ tmp_dout.ShareDataWith (*dout);
136140 if (axes.size () != 0 ) {
137141 std::vector<int64_t > reduced_dout_dims;
138142 for (auto i = reduce_ndim; i < dout->dims ().size (); ++i) {
139143 reduced_dout_dims.push_back (dout->dims ()[i]);
140144 }
145+ tmp_dout.Resize (framework::make_ddim (reduced_dout_dims));
141146 reduced_dout.Resize (framework::make_ddim (reduced_dout_dims));
142147 reduced_dout.mutable_data <T>(ctx.GetPlace ());
143148 const auto & runner = NpuOpRunner (" ReduceSumD" , {*dout}, {reduced_dout},
144149 {{" axes" , axes}, {" keep_dims" , false }});
145150 runner.Run (stream);
146- tmp_dout = & reduced_dout;
151+ tmp_dout = reduced_dout;
147152 }
148153
149154 // case 2: reduce axis of dout in which dim is 1
@@ -158,11 +163,11 @@ class ExpandV2NPUGradKernel : public framework::OpKernel<T> {
158163 }
159164 }
160165 if (axes.size () != 0 ) {
161- const auto & runner = NpuOpRunner (" ReduceSumD" , {* tmp_dout}, {*dx},
166+ const auto & runner = NpuOpRunner (" ReduceSumD" , {tmp_dout}, {*dx},
162167 {{" axes" , axes}, {" keep_dims" , true }});
163168 runner.Run (stream);
164169 } else {
165- framework::TensorCopySync (* tmp_dout, ctx.GetPlace (), dx);
170+ framework::TensorCopySync (tmp_dout, ctx.GetPlace (), dx);
166171 }
167172 }
168173};
@@ -181,4 +186,6 @@ REGISTER_OP_NPU_KERNEL(
181186REGISTER_OP_NPU_KERNEL (
182187 expand_v2_grad,
183188 ops::ExpandV2NPUGradKernel<paddle::platform::NPUDeviceContext, float >,
189+ ops::ExpandV2NPUGradKernel<paddle::platform::NPUDeviceContext,
190+ paddle::platform::float16>,
184191 ops::ExpandV2NPUGradKernel<paddle::platform::NPUDeviceContext, int >);
0 commit comments