Commit ee8d5cd

[Host] add pad2d; test=develop
1 parent 104921f commit ee8d5cd

5 files changed

Lines changed: 410 additions & 5 deletions


lite/backends/host/math/pad2d.h

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <algorithm>
#include <string>
#include <vector>

namespace paddle {
namespace lite {
namespace host {
namespace math {

// Mirror padding (reflect-101: edge pixels are not repeated) for NCHW layout.
template <class T>
void Pad2DReflectNCHW(const T* in_data,
                      const int num,
                      const int channels,
                      const int in_height,
                      const int in_width,
                      const int out_height,
                      const int out_width,
                      const int pad_top,
                      const int pad_left,
                      T* out_data) {
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < channels; ++c) {
      for (int out_h = 0; out_h < out_height; ++out_h) {
        for (int out_w = 0; out_w < out_width; ++out_w) {
          int in_h = out_h - pad_top;
          int in_w = out_w - pad_left;
          in_h = std::max(in_h, -in_h);  // reflect by 0
          in_h = std::min(in_h, 2 * in_height - in_h - 2);  // reflect by in_height
          in_w = std::max(in_w, -in_w);  // reflect by 0
          in_w = std::min(in_w, 2 * in_width - in_w - 2);  // reflect by in_width
          out_data[out_h * out_width + out_w] = in_data[in_h * in_width + in_w];
        }
      }
      // Advance to the next channel plane.
      in_data += in_height * in_width;
      out_data += out_height * out_width;
    }
  }
}

// Replicate padding (clamp to the nearest edge pixel) for NCHW layout.
template <typename T>
void Pad2DEdgeNCHW(const T* in_data,
                   const int num,
                   const int channels,
                   const int in_height,
                   const int in_width,
                   const int out_height,
                   const int out_width,
                   const int pad_top,
                   const int pad_left,
                   T* out_data) {
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < channels; ++c) {
      for (int out_h = 0; out_h < out_height; ++out_h) {
        for (int out_w = 0; out_w < out_width; ++out_w) {
          int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0));
          int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0));
          out_data[out_h * out_width + out_w] = in_data[in_h * in_width + in_w];
        }
      }
      in_data += in_height * in_width;
      out_data += out_height * out_width;
    }
  }
}

// Constant padding (fill the border with `value`) for NCHW layout.
template <typename T>
void Pad2DConstNCHW(const T* in_data,
                    const int num,
                    const int channels,
                    const int in_height,
                    const int in_width,
                    const int out_height,
                    const int out_width,
                    const int pad_top,
                    const int pad_left,
                    T value,
                    T* out_data) {
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < channels; ++c) {
      for (int out_h = 0; out_h < out_height; ++out_h) {
        for (int out_w = 0; out_w < out_width; ++out_w) {
          int in_h = out_h - pad_top;
          int in_w = out_w - pad_left;
          out_data[out_h * out_width + out_w] =
              (in_h < 0 || in_w < 0 || in_h >= in_height || in_w >= in_width)
                  ? value
                  : in_data[in_h * in_width + in_w];
        }
      }
      in_data += in_height * in_width;
      out_data += out_height * out_width;
    }
  }
}

// Mirror padding for NHWC layout; the channel loop is innermost so each
// copied pixel moves `channels` contiguous elements.
template <typename T>
void Pad2DReflectNHWC(const T* in_data,
                      const int num,
                      const int channels,
                      const int in_height,
                      const int in_width,
                      const int out_height,
                      const int out_width,
                      const int pad_top,
                      const int pad_left,
                      T* out_data) {
  for (int n = 0; n < num; ++n) {
    for (int out_h = 0; out_h < out_height; ++out_h) {
      for (int out_w = 0; out_w < out_width; ++out_w) {
        const int out_index = (out_h * out_width + out_w) * channels;
        int in_h = out_h - pad_top;
        int in_w = out_w - pad_left;
        in_h = std::max(in_h, -in_h);
        in_h = std::min(in_h, 2 * in_height - in_h - 2);
        in_w = std::max(in_w, -in_w);
        in_w = std::min(in_w, 2 * in_width - in_w - 2);
        const int in_index = (in_h * in_width + in_w) * channels;

        for (int c = 0; c < channels; ++c) {
          out_data[out_index + c] = in_data[in_index + c];
        }
      }
    }
    // Advance to the next image in the batch.
    in_data += in_height * in_width * channels;
    out_data += out_height * out_width * channels;
  }
}

// Replicate padding for NHWC layout.
template <typename T>
void Pad2DEdgeNHWC(const T* in_data,
                   const int num,
                   const int channels,
                   const int in_height,
                   const int in_width,
                   const int out_height,
                   const int out_width,
                   const int pad_top,
                   const int pad_left,
                   T* out_data) {
  for (int n = 0; n < num; ++n) {
    for (int out_h = 0; out_h < out_height; ++out_h) {
      for (int out_w = 0; out_w < out_width; ++out_w) {
        const int out_index = (out_h * out_width + out_w) * channels;
        int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0));
        int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0));
        const int in_index = (in_h * in_width + in_w) * channels;
        for (int c = 0; c < channels; ++c) {
          out_data[out_index + c] = in_data[in_index + c];
        }
      }
    }
    in_data += in_height * in_width * channels;
    out_data += out_height * out_width * channels;
  }
}

// Constant padding for NHWC layout.
template <typename T>
void Pad2DConstNHWC(const T* in_data,
                    const int num,
                    const int channels,
                    const int in_height,
                    const int in_width,
                    const int out_height,
                    const int out_width,
                    const int pad_top,
                    const int pad_left,
                    T value,
                    T* out_data) {
  for (int n = 0; n < num; ++n) {
    for (int out_h = 0; out_h < out_height; ++out_h) {
      for (int out_w = 0; out_w < out_width; ++out_w) {
        int in_h = out_h - pad_top;
        int in_w = out_w - pad_left;
        const int out_index = (out_h * out_width + out_w) * channels;
        if (in_h < 0 || in_w < 0 || in_h >= in_height || in_w >= in_width) {
          for (int c = 0; c < channels; ++c) {
            out_data[out_index + c] = value;
          }
        } else {
          const int in_index = (in_h * in_width + in_w) * channels;
          for (int c = 0; c < channels; ++c) {
            out_data[out_index + c] = in_data[in_index + c];
          }
        }
      }
    }
    in_data += in_height * in_width * channels;
    out_data += out_height * out_width * channels;
  }
}

}  // namespace math
}  // namespace host
}  // namespace lite
}  // namespace paddle
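
Note: the reflect variants above fold each out-of-range coordinate back into the input with two clamps: std::max(in_h, -in_h) mirrors across row 0, and std::min(in_h, 2 * in_height - in_h - 2) mirrors across the last row without repeating the edge pixel (reflect-101). A minimal standalone sketch of that index mapping, not part of the commit (the Reflect helper is hypothetical, and it assumes pad < size so a single reflection suffices):

#include <algorithm>
#include <cstdio>

// Hypothetical helper: the same two clamps used by Pad2DReflectNCHW above.
static int Reflect(int idx, int size) {
  idx = std::max(idx, -idx);                // mirror across index 0
  idx = std::min(idx, 2 * size - idx - 2);  // mirror across index size - 1
  return idx;
}

int main() {
  const int in_size = 3, pad = 2;
  for (int out = 0; out < in_size + 2 * pad; ++out) {
    std::printf("out %d -> in %d\n", out, Reflect(out - pad, in_size));
  }
  return 0;
}

For a width-3 row padded by 2 on each side this prints input columns 2 1 0 1 2 1 0: the border mirrors the interior and the edge pixels are not duplicated.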

lite/kernels/host/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ add_kernel(gather_compute_host Host extra SRCS gather_compute.cc DEPS ${lite_ker
 add_kernel(gather_nd_compute_host Host extra SRCS gather_nd_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(gather_tree_compute_host Host extra SRCS gather_tree_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(increment_compute_host Host extra SRCS increment_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(pad2d_compute_host Host extra SRCS pad2d_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(pad3d_compute_host Host extra SRCS pad3d_compute.cc DEPS ${lite_kernel_deps} math_host)
 add_kernel(select_input_compute_host Host extra SRCS select_input_compute.cc DEPS ${lite_kernel_deps} math_host)
 add_kernel(tensor_array_to_tensor_compute_host Host extra SRCS tensor_array_to_tensor_compute.cc DEPS ${lite_kernel_deps} math_host)

lite/kernels/host/pad2d_compute.cc

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/host/pad2d_compute.h"
#include <algorithm>
#include <string>
#include <vector>
#include "lite/backends/host/math/pad2d.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

template <class T>
void Pad2dCompute<T>::Run() {
  auto& param = this->template Param<param_t>();
  auto pads = param.paddings;  // [top, bottom, left, right]
  auto mode = param.mode;
  auto data_format = param.data_format;
  T value = static_cast<T>(param.pad_value);

  auto* x = param.X;
  auto in_dims = x->dims();
  auto* in_data = x->template data<T>();

  // Infer the output shape from the four paddings before writing.
  auto* out = param.Out;
  if (data_format == "NCHW") {
    out->Resize({in_dims[0],
                 in_dims[1],
                 in_dims[2] + pads[0] + pads[1],
                 in_dims[3] + pads[2] + pads[3]});
  } else {
    out->Resize({in_dims[0],
                 in_dims[1] + pads[0] + pads[1],
                 in_dims[2] + pads[2] + pads[3],
                 in_dims[3]});
  }
  auto out_dims = out->dims();
  T* out_data = out->template mutable_data<T>();

  const int pad_top = pads[0];
  const int pad_left = pads[2];
  const int num = in_dims[0];
  if (data_format == "NCHW") {
    const int channels = in_dims[1];
    const int in_height = in_dims[2];
    const int in_width = in_dims[3];
    const int out_height = out_dims[2];
    const int out_width = out_dims[3];
    if (mode == "reflect") {
      lite::host::math::Pad2DReflectNCHW(in_data,
                                         num,
                                         channels,
                                         in_height,
                                         in_width,
                                         out_height,
                                         out_width,
                                         pad_top,
                                         pad_left,
                                         out_data);
    } else if (mode == "edge") {
      lite::host::math::Pad2DEdgeNCHW(in_data,
                                      num,
                                      channels,
                                      in_height,
                                      in_width,
                                      out_height,
                                      out_width,
                                      pad_top,
                                      pad_left,
                                      out_data);
    } else {
      lite::host::math::Pad2DConstNCHW(in_data,
                                       num,
                                       channels,
                                       in_height,
                                       in_width,
                                       out_height,
                                       out_width,
                                       pad_top,
                                       pad_left,
                                       value,
                                       out_data);
    }
  } else {
    const int channels = in_dims[3];
    const int in_height = in_dims[1];
    const int in_width = in_dims[2];
    const int out_height = out_dims[1];
    const int out_width = out_dims[2];
    if (mode == "reflect") {
      lite::host::math::Pad2DReflectNHWC(in_data,
                                         num,
                                         channels,
                                         in_height,
                                         in_width,
                                         out_height,
                                         out_width,
                                         pad_top,
                                         pad_left,
                                         out_data);
    } else if (mode == "edge") {
      lite::host::math::Pad2DEdgeNHWC(in_data,
                                      num,
                                      channels,
                                      in_height,
                                      in_width,
                                      out_height,
                                      out_width,
                                      pad_top,
                                      pad_left,
                                      out_data);
    } else {
      lite::host::math::Pad2DConstNHWC(in_data,
                                       num,
                                       channels,
                                       in_height,
                                       in_width,
                                       out_height,
                                       out_width,
                                       pad_top,
                                       pad_left,
                                       value,
                                       out_data);
    }
  }
}

}  // namespace host
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

REGISTER_LITE_KERNEL(pad2d,
                     kHost,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::host::Pad2dCompute<float>,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
    .Finalize();
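
Note: Run() above infers the output shape before dispatching, taking paddings in [top, bottom, left, right] order: height grows by pads[0] + pads[1] and width by pads[2] + pads[3], at dims 2/3 for NCHW and dims 1/2 for NHWC. A standalone sketch of that arithmetic, not part of the commit (PaddedShape is a hypothetical helper, not a Lite API):

#include <array>
#include <cstdint>
#include <cstdio>
#include <string>

// Hypothetical helper mirroring the Resize logic in Pad2dCompute<T>::Run().
std::array<int64_t, 4> PaddedShape(const std::array<int64_t, 4>& in,
                                   const std::array<int, 4>& pads,
                                   const std::string& data_format) {
  if (data_format == "NCHW") {
    return {in[0], in[1],
            in[2] + pads[0] + pads[1],   // height: top + bottom
            in[3] + pads[2] + pads[3]};  // width: left + right
  }
  // NHWC: height and width sit at dims 1 and 2.
  return {in[0], in[1] + pads[0] + pads[1], in[2] + pads[2] + pads[3], in[3]};
}

int main() {
  const auto out = PaddedShape({1, 3, 8, 8}, {1, 1, 2, 2}, "NCHW");
  std::printf("%lld x %lld x %lld x %lld\n",  // prints: 1 x 3 x 10 x 12
              static_cast<long long>(out[0]), static_cast<long long>(out[1]),
              static_cast<long long>(out[2]), static_cast<long long>(out[3]));
  return 0;
}

For a 1x3x8x8 NCHW input with paddings {1, 1, 2, 2} this yields 1x3x10x12, matching the Resize call in the kernel.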
