Skip to content

Commit 33f232f

Browse files
committed
[NPU] add NPU ops of stack and unstack, test=develop
1 parent 98c7191 commit 33f232f

File tree

4 files changed

+404
-120
lines changed

4 files changed

+404
-120
lines changed

paddle/fluid/operators/stack_op_npu.cc

Lines changed: 42 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#ifdef PADDLE_WITH_ASCEND_CL
16-
#include <memory>
17-
#include <string>
18-
#include <vector>
19-
20-
#include "paddle/fluid/operators/activation_op.h"
21-
#include "paddle/fluid/operators/npu_op_runner.h"
2215
#include "paddle/fluid/operators/stack_op.h"
23-
#include "paddle/fluid/operators/unsqueeze_op.h"
16+
#include "paddle/fluid/operators/npu_op_runner.h"
2417

2518
namespace paddle {
2619
namespace operators {
@@ -32,64 +25,56 @@ class StackNPUKernel : public framework::OpKernel<T> {
3225
public:
3326
void Compute(const framework::ExecutionContext& ctx) const override {
3427
auto x = ctx.MultiInput<Tensor>("X");
35-
int32_t N = x.size();
28+
auto* y = ctx.Output<Tensor>("Y");
29+
int axis = ctx.Attr<int>("axis");
30+
if (axis < 0) axis += (x[0]->dims().size() + 1);
31+
int num = static_cast<int>(x.size());
3632

37-
PADDLE_ENFORCE_GT(
38-
N, 0, platform::errors::InvalidArgument("number of input Tensor <= 0"));
33+
PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
34+
"number of input Tensor <= 0"));
35+
36+
auto stream =
37+
ctx.template device_context<paddle::platform::NPUDeviceContext>()
38+
.stream();
3939

4040
std::vector<paddle::framework::Tensor> x_list;
41-
for (int i = 0; i < N; i++) {
41+
for (int i = 0; i < num; i++) {
4242
x_list.push_back(*x[i]);
4343
}
44+
y->mutable_data<T>(ctx.GetPlace());
4445

45-
int axis = ctx.Attr<int>("axis");
46+
const auto& runner =
47+
NpuOpRunner("Pack", {x_list}, {*y}, {{"axis", axis}, {"N", num}});
48+
runner.Run(stream);
49+
}
50+
};
4651

47-
if (axis < 0) {
48-
axis = axis + x_list[0].dims().size() + 1;
49-
}
50-
auto* out = ctx.Output<Tensor>("Y");
52+
template <typename DeviceContext, typename T>
53+
class StackGradNPUKernel : public framework::OpKernel<T> {
54+
public:
55+
void Compute(const framework::ExecutionContext& ctx) const override {
56+
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
57+
auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
58+
int axis = ctx.Attr<int>("axis");
59+
if (axis < 0) axis += dy->dims().size();
60+
int num = dy->dims()[axis];
5161

52-
auto place = ctx.GetPlace();
62+
PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
63+
"number of input Tensor <= 0"));
5364

5465
auto stream =
5566
ctx.template device_context<paddle::platform::NPUDeviceContext>()
5667
.stream();
5768

58-
out->mutable_data<T>(place);
59-
60-
if (axis != 0) {
61-
auto x_dim = x_list[0].dims();
62-
std::vector<int> vec_dim_tmp;
63-
vec_dim_tmp.push_back(N);
64-
for (auto i = 0; i < x_dim.size(); ++i) {
65-
vec_dim_tmp.push_back(x_dim[i]);
66-
}
67-
68-
Tensor tmp_stack(out->type());
69-
tmp_stack.Resize(framework::make_ddim(vec_dim_tmp));
70-
tmp_stack.mutable_data<T>(ctx.GetPlace());
71-
72-
const auto& runner =
73-
NpuOpRunner("Pack", {x_list}, {tmp_stack}, {{"axis", 0}, {"N", N}});
74-
runner.Run(stream);
75-
76-
std::vector<int64_t> vec_trans;
77-
for (auto i = 1; i <= x_dim.size(); ++i) {
78-
vec_trans.push_back(i);
79-
if (i == axis) {
80-
vec_trans.push_back(0);
81-
}
82-
}
83-
84-
const auto& runner_trans_final =
85-
NpuOpRunner("TransposeD", {tmp_stack}, {*out}, {{"perm", vec_trans}});
86-
runner_trans_final.Run(stream);
87-
88-
} else {
89-
const auto& runner =
90-
NpuOpRunner("Pack", {x_list}, {*out}, {{"axis", axis}, {"N", N}});
91-
runner.Run(stream);
69+
std::vector<paddle::framework::Tensor> dx_list;
70+
for (int i = 0; i < num; i++) {
71+
dx[i]->mutable_data<T>(ctx.GetPlace());
72+
dx_list.push_back(*dx[i]);
9273
}
74+
75+
const auto& runner =
76+
NpuOpRunner("Unpack", {*dy}, {dx_list}, {{"axis", axis}, {"num", num}});
77+
runner.Run(stream);
9378
}
9479
};
9580

@@ -103,4 +88,8 @@ REGISTER_OP_NPU_KERNEL(
10388
ops::StackNPUKernel<paddle::platform::NPUDeviceContext,
10489
paddle::platform::float16>);
10590

106-
#endif
91+
REGISTER_OP_NPU_KERNEL(
92+
stack_grad,
93+
ops::StackGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
94+
ops::StackGradNPUKernel<paddle::platform::NPUDeviceContext,
95+
paddle::platform::float16>);
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/unstack_op.h"
16+
#include "paddle/fluid/operators/npu_op_runner.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
template <typename DeviceContext, typename T>
22+
class UnStackNPUKernel : public framework::OpKernel<T> {
23+
public:
24+
void Compute(const framework::ExecutionContext &ctx) const override {
25+
auto *dy = ctx.Input<Tensor>("X");
26+
auto dx = ctx.MultiOutput<Tensor>("Y");
27+
int axis = ctx.Attr<int>("axis");
28+
if (axis < 0) axis += dy->dims().size();
29+
int num = dy->dims()[axis];
30+
31+
auto stream =
32+
ctx.template device_context<paddle::platform::NPUDeviceContext>()
33+
.stream();
34+
35+
std::vector<paddle::framework::Tensor> dx_list;
36+
for (int i = 0; i < num; i++) {
37+
dx[i]->mutable_data<T>(ctx.GetPlace());
38+
dx_list.push_back(*dx[i]);
39+
}
40+
41+
const auto &runner =
42+
NpuOpRunner("Unpack", {*dy}, {dx_list}, {{"axis", axis}, {"num", num}});
43+
runner.Run(stream);
44+
}
45+
};
46+
47+
template <typename DeviceContext, typename T>
48+
class UnStackGradNPUKernel : public framework::OpKernel<T> {
49+
public:
50+
void Compute(const framework::ExecutionContext &ctx) const override {
51+
auto x = ctx.MultiInput<Tensor>(framework::GradVarName("Y"));
52+
auto *y = ctx.Output<Tensor>(framework::GradVarName("X"));
53+
int axis = ctx.Attr<int>("axis");
54+
if (axis < 0) axis += (x[0]->dims().size() + 1);
55+
int num = static_cast<int>(x.size());
56+
57+
auto stream =
58+
ctx.template device_context<paddle::platform::NPUDeviceContext>()
59+
.stream();
60+
61+
std::vector<paddle::framework::Tensor> x_list;
62+
for (int i = 0; i < num; i++) {
63+
x_list.push_back(*x[i]);
64+
}
65+
y->mutable_data<T>(ctx.GetPlace());
66+
67+
const auto &runner =
68+
NpuOpRunner("Pack", {x_list}, {*y}, {{"axis", axis}, {"N", num}});
69+
runner.Run(stream);
70+
}
71+
};
72+
73+
} // namespace operators
74+
} // namespace paddle
75+
76+
namespace plat = paddle::platform;
77+
namespace ops = paddle::operators;
78+
79+
REGISTER_OP_NPU_KERNEL(
80+
unstack, ops::UnStackNPUKernel<plat::NPUDeviceContext, float>,
81+
ops::UnStackNPUKernel<plat::NPUDeviceContext, plat::float16>);
82+
83+
REGISTER_OP_NPU_KERNEL(
84+
unstack_grad, ops::UnStackGradNPUKernel<plat::NPUDeviceContext, float>,
85+
ops::UnStackGradNPUKernel<plat::NPUDeviceContext, plat::float16>);

0 commit comments

Comments
 (0)