@@ -18,6 +18,33 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
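+// Resolve the target shape for expand_v2 from the highest-priority source
+// that is present: the "Shape" input tensor, then the "expand_shapes_tensor"
+// list of scalar tensors, and finally the "shape" attribute.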
+inline std::vector<int> get_expand_shape_npu(
+    const framework::ExecutionContext& ctx) {
+  std::vector<int> vec_expand_shape;
+  auto list_expand_shapes_tensor =
+      ctx.MultiInput<framework::Tensor>("expand_shapes_tensor");
+  if (ctx.HasInput("Shape")) {
+    auto* shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
+    std::vector<int> out_data;
+    TensorToVector(*shape_tensor, ctx.device_context(), &out_data);
+    for (int i = 0; i < static_cast<int>(out_data.size()); ++i) {
+      vec_expand_shape.push_back(out_data[i]);
+    }
+    return vec_expand_shape;
+  } else if (list_expand_shapes_tensor.size() > 0) {
+    // gather one value from each scalar tensor in the list
+    for (size_t i = 0; i < list_expand_shapes_tensor.size(); ++i) {
+      auto tensor = list_expand_shapes_tensor[i];
+      std::vector<int> out_data;
+      TensorToVector(*tensor, ctx.device_context(), &out_data);
+      vec_expand_shape.push_back(out_data[0]);
+    }
+    return vec_expand_shape;
+  } else {
+    return ctx.Attr<std::vector<int>>("shape");
+  }
+}
+
 using Tensor = framework::Tensor;
 template <typename DeviceContext, typename T>
 class ExpandV2NPUKernel : public framework::OpKernel<T> {
@@ -26,27 +53,7 @@ class ExpandV2NPUKernel : public framework::OpKernel<T> {
     auto* X = ctx.Input<framework::Tensor>("X");
     auto* Out = ctx.Output<framework::Tensor>("Out");
 
-    std::vector<int> expand_shape;
-    auto list_expand_shapes_tensor =
-        ctx.MultiInput<framework::Tensor>("expand_shapes_tensor");
-    if (ctx.HasInput("Shape")) {
-      auto* shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
-      std::vector<int> out_data;
-      TensorToVector(*shape_tensor, ctx.device_context(), &out_data);
-      for (int i = 0; i < static_cast<int>(out_data.size()); ++i) {
-        expand_shape.push_back(out_data[i]);
-      }
-    } else if (list_expand_shapes_tensor.size() > 0) {
-      // get tensor from
-      for (size_t i = 0; i < list_expand_shapes_tensor.size(); ++i) {
-        auto tensor = list_expand_shapes_tensor[i];
-        std::vector<int> out_data;
-        TensorToVector(*tensor, ctx.device_context(), &out_data);
-        expand_shape.push_back(out_data[0]);
-      }
-    } else {
-      expand_shape = ctx.Attr<std::vector<int>>("shape");
-    }
+    std::vector<int> expand_shape = get_expand_shape_npu(ctx);
 
     framework::NPUAttributeMap attr_input = {{"shape", expand_shape}};
 
@@ -97,6 +104,62 @@ class ExpandV2NPUKernel : public framework::OpKernel<T> {
     runner.Run(stream);
   }
 };
+
+template <typename DeviceContext, typename T>
+class ExpandV2NPUGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    dx->mutable_data<T>(ctx.GetPlace());
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
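+    // Worked example of the reductions below: expanding X [5] -> Out [2, 5]
+    // yields dout [2, 5]; case 1 sums away the leading axis 0 to recover
+    // dx [5]. Expanding X [1, 5] -> Out [4, 5] keeps the rank, so case 2
+    // instead sums axis 0 with keep_dims=true to produce dx [1, 5].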
+    // case 1: reduce dout dims to dx dims
+    // For example: [2, 120] --> [120]
+    auto reduce_ndim = dout->dims().size() - dx->dims().size();
+    std::vector<int> axes;
+    for (auto i = 0; i < reduce_ndim; ++i) {
+      axes.push_back(i);
+    }
+    Tensor* tmp_dout = const_cast<Tensor*>(dout);
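+    // tmp_dout aliases dout until case 1 replaces it with the rank-reduced
+    // copy; reduced_dout only allocates storage when that reduction runs.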
+    Tensor reduced_dout(dx->type());
+    if (axes.size() != 0) {
+      std::vector<int64_t> reduced_dout_dims;
+      for (auto i = reduce_ndim; i < dout->dims().size(); ++i) {
+        reduced_dout_dims.push_back(dout->dims()[i]);
+      }
+      reduced_dout.Resize(framework::make_ddim(reduced_dout_dims));
+      reduced_dout.mutable_data<T>(ctx.GetPlace());
+      const auto& runner = NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout},
+                                       {{"axes", axes}, {"keep_dims", false}});
+      runner.Run(stream);
+      tmp_dout = &reduced_dout;
+    }
+
+    // case 2: reduce the axes of dout where the corresponding dx dim is 1
+    // For example: [12, 140] --> [1, 140]
+
+    // case 3: copy dout to dx when the shapes match and no dx dim is 1
+    // For example: [2, 10, 5] --> [2, 10, 5]
+    axes.clear();
+    for (auto i = 0; i < dx->dims().size(); ++i) {
+      if (dx->dims()[i] == 1) {
+        axes.push_back(i);
+      }
+    }
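+    // Sum over the size-1 axes with keep_dims=true so dx keeps its rank
+    // (case 2); if there are none, the shapes already match and a plain
+    // copy suffices (case 3).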
+    if (axes.size() != 0) {
+      const auto& runner = NpuOpRunner("ReduceSumD", {*tmp_dout}, {*dx},
+                                       {{"axes", axes}, {"keep_dims", true}});
+      runner.Run(stream);
+    } else {
+      framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dx);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -107,3 +170,8 @@ REGISTER_OP_NPU_KERNEL(
     ops::ExpandV2NPUKernel<paddle::platform::NPUDeviceContext,
                            paddle::platform::float16>,
     ops::ExpandV2NPUKernel<paddle::platform::NPUDeviceContext, int>);
+
+REGISTER_OP_NPU_KERNEL(
+    expand_v2_grad,
+    ops::ExpandV2NPUGradKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::ExpandV2NPUGradKernel<paddle::platform::NPUDeviceContext, int>);