Skip to content

Commit ef517a5

Browse files
authored
[NPU] Support npu kernel for pad3d op (#34815)
* [NPU] Support npu kernel for pad3d op * fix for comment of zhouwei25 * fix some bugs according to qili93's comments * add support and test for paddings in input * delete VLOG used for debug
1 parent 6bacfb0 commit ef517a5

File tree

2 files changed

+674
-0
lines changed

2 files changed

+674
-0
lines changed
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <cmath>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/npu_op_runner.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
using Tensor = framework::Tensor;
23+
24+
static inline std::vector<int> GetPaddings(
25+
const framework::ExecutionContext& context) {
26+
std::vector<int> paddings(6);
27+
auto* paddings_t = context.Input<Tensor>("Paddings");
28+
if (paddings_t) {
29+
TensorToVector(*paddings_t, context.device_context(), &paddings);
30+
} else {
31+
auto pads = context.Attr<std::vector<int>>("paddings");
32+
std::copy(pads.begin(), pads.end(), paddings.data());
33+
}
34+
return paddings;
35+
}
36+
37+
template <typename T>
38+
class Pad3dNPUKernel : public framework::OpKernel<T> {
39+
public:
40+
void Compute(const framework::ExecutionContext& context) const override {
41+
auto* x = context.Input<Tensor>("X");
42+
auto in_dims = x->dims();
43+
44+
std::vector<int> pads = GetPaddings(context);
45+
auto mode = context.Attr<std::string>("mode");
46+
float value = context.Attr<float>("value");
47+
auto data_format = context.Attr<std::string>("data_format");
48+
49+
auto* out = context.Output<Tensor>("Out");
50+
51+
PADDLE_ENFORCE_LT(abs(value), 1e-5,
52+
platform::errors::Unimplemented(
53+
"Ascend npu only support constant_values=0 right now,"
54+
"but received constant_value is %f .",
55+
value));
56+
57+
PADDLE_ENFORCE_EQ(mode, "constant",
58+
platform::errors::Unimplemented(
59+
"Ascend npu only support mode=constant right now,"
60+
"but received mode is %s .",
61+
mode));
62+
63+
std::vector<int> paddings(
64+
{0, 0, 0, 0, pads[4], pads[5], pads[2], pads[3], pads[0], pads[1]});
65+
if (data_format == "NCDHW") {
66+
out->Resize({in_dims[0], in_dims[1], in_dims[2] + pads[4] + pads[5],
67+
in_dims[3] + pads[2] + pads[3],
68+
in_dims[4] + pads[0] + pads[1]});
69+
} else {
70+
out->Resize({in_dims[0], in_dims[1] + pads[4] + pads[5],
71+
in_dims[2] + pads[2] + pads[3],
72+
in_dims[3] + pads[0] + pads[1], in_dims[4]});
73+
paddings = {0, 0, pads[4], pads[5], pads[2],
74+
pads[3], pads[0], pads[1], 0, 0};
75+
}
76+
out->mutable_data<T>(context.GetPlace());
77+
78+
NpuOpRunner runner;
79+
runner.SetType("PadV3")
80+
.AddInput(*x)
81+
.AddInput(std::move(paddings))
82+
.AddInput(
83+
std::vector<int>({0})) // npu only support constant_value=0 now
84+
.AddOutput(*out)
85+
.AddAttr("mode", mode);
86+
87+
auto stream =
88+
context.template device_context<paddle::platform::NPUDeviceContext>()
89+
.stream();
90+
runner.Run(stream);
91+
}
92+
};
93+
94+
template <typename T>
95+
class Pad3dGradNPUKernel : public framework::OpKernel<T> {
96+
public:
97+
void Compute(const framework::ExecutionContext& context) const override {
98+
std::vector<int> pads = GetPaddings(context);
99+
auto mode = context.Attr<std::string>("mode");
100+
auto data_format = context.Attr<std::string>("data_format");
101+
102+
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
103+
auto* d_in = context.Output<Tensor>(framework::GradVarName("X"));
104+
auto d_in_dims = d_in->dims();
105+
d_in->mutable_data<T>(context.GetPlace());
106+
107+
const int pad_left = pads[0];
108+
const int pad_top = pads[2];
109+
const int pad_front = pads[4];
110+
111+
auto stream =
112+
context.template device_context<paddle::platform::NPUDeviceContext>()
113+
.stream();
114+
115+
std::vector<int64_t> size(
116+
{d_in_dims[0], d_in_dims[1], d_in_dims[2], d_in_dims[3], d_in_dims[4]});
117+
if (mode == "constant") { // this method can be only used for constant mode
118+
std::vector<int> offsets({0, 0, pad_front, pad_top, pad_left});
119+
if (data_format == "NDHWC") {
120+
offsets = {0, pad_front, pad_top, pad_left, 0};
121+
}
122+
const auto& runner = NpuOpRunner("SliceD", {*d_out}, {*d_in},
123+
{{"offsets", offsets}, {"size", size}});
124+
runner.Run(stream);
125+
}
126+
}
127+
};
128+
129+
} // namespace operators
130+
} // namespace paddle
131+
132+
namespace ops = paddle::operators;
133+
namespace plat = paddle::platform;
134+
135+
REGISTER_OP_NPU_KERNEL(pad3d, ops::Pad3dNPUKernel<plat::float16>,
136+
ops::Pad3dNPUKernel<float>, ops::Pad3dNPUKernel<int>);
137+
138+
REGISTER_OP_NPU_KERNEL(pad3d_grad, ops::Pad3dNPUKernel<plat::float16>,
139+
ops::Pad3dGradNPUKernel<float>);

0 commit comments

Comments
 (0)