@@ -12,15 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_ASCEND_CL
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/operators/activation_op.h"
-#include "paddle/fluid/operators/npu_op_runner.h"
 #include "paddle/fluid/operators/stack_op.h"
-#include "paddle/fluid/operators/unsqueeze_op.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
 
 namespace paddle {
 namespace operators {
@@ -32,64 +25,56 @@ class StackNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto x = ctx.MultiInput<Tensor>("X");
-    int32_t N = x.size();
+    auto* y = ctx.Output<Tensor>("Y");
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis += (x[0]->dims().size() + 1);
+    int num = static_cast<int>(x.size());
 
-    PADDLE_ENFORCE_GT(
-        N, 0, platform::errors::InvalidArgument("number of input Tensor <= 0"));
+    PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
+                                  "number of input Tensor <= 0"));
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
 
     std::vector<paddle::framework::Tensor> x_list;
-    for (int i = 0; i < N; i++) {
+    for (int i = 0; i < num; i++) {
       x_list.push_back(*x[i]);
     }
+    y->mutable_data<T>(ctx.GetPlace());
 
-    int axis = ctx.Attr<int>("axis");
+    const auto& runner =
+        NpuOpRunner("Pack", {x_list}, {*y}, {{"axis", axis}, {"N", num}});
+    runner.Run(stream);
+  }
+};
 
-    if (axis < 0) {
-      axis = axis + x_list[0].dims().size() + 1;
-    }
-    auto* out = ctx.Output<Tensor>("Y");
+template <typename DeviceContext, typename T>
+class StackGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis += dy->dims().size();
+    int num = dy->dims()[axis];
 
-    auto place = ctx.GetPlace();
+    PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
+                                  "number of input Tensor <= 0"));
 
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
 
-    out->mutable_data<T>(place);
-
-    if (axis != 0) {
-      auto x_dim = x_list[0].dims();
-      std::vector<int> vec_dim_tmp;
-      vec_dim_tmp.push_back(N);
-      for (auto i = 0; i < x_dim.size(); ++i) {
-        vec_dim_tmp.push_back(x_dim[i]);
-      }
-
-      Tensor tmp_stack(out->type());
-      tmp_stack.Resize(framework::make_ddim(vec_dim_tmp));
-      tmp_stack.mutable_data<T>(ctx.GetPlace());
-
-      const auto& runner =
-          NpuOpRunner("Pack", {x_list}, {tmp_stack}, {{"axis", 0}, {"N", N}});
-      runner.Run(stream);
-
-      std::vector<int64_t> vec_trans;
-      for (auto i = 1; i <= x_dim.size(); ++i) {
-        vec_trans.push_back(i);
-        if (i == axis) {
-          vec_trans.push_back(0);
-        }
-      }
-
-      const auto& runner_trans_final =
-          NpuOpRunner("TransposeD", {tmp_stack}, {*out}, {{"perm", vec_trans}});
-      runner_trans_final.Run(stream);
-
-    } else {
-      const auto& runner =
-          NpuOpRunner("Pack", {x_list}, {*out}, {{"axis", axis}, {"N", N}});
-      runner.Run(stream);
+    std::vector<paddle::framework::Tensor> dx_list;
+    for (int i = 0; i < num; i++) {
+      dx[i]->mutable_data<T>(ctx.GetPlace());
+      dx_list.push_back(*dx[i]);
     }
+
+    const auto& runner =
+        NpuOpRunner("Unpack", {*dy}, {dx_list}, {{"axis", axis}, {"num", num}});
+    runner.Run(stream);
   }
 };
 
@@ -103,4 +88,8 @@ REGISTER_OP_NPU_KERNEL(
     ops::StackNPUKernel<paddle::platform::NPUDeviceContext,
                         paddle::platform::float16>);
 
-#endif
+REGISTER_OP_NPU_KERNEL(
+    stack_grad,
+    ops::StackGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::StackGradNPUKernel<paddle::platform::NPUDeviceContext,
+                            paddle::platform::float16>);