Skip to content

Commit b7db353

Browse files
authored
Merge pull request #7922 from Noplz/box_coder_op
add box coder op
2 parents 4e7e39b + 7d8d9db commit b7db353

4 files changed

Lines changed: 549 additions & 0 deletions

File tree

paddle/operators/box_coder_op.cc

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/operators/box_coder_op.h"
13+
14+
namespace paddle {
15+
namespace operators {
16+
17+
class BoxCoderOp : public framework::OperatorWithKernel {
18+
public:
19+
using framework::OperatorWithKernel::OperatorWithKernel;
20+
21+
protected:
22+
void InferShape(framework::InferShapeContext *ctx) const override {
23+
PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
24+
"Input(PriorBox) of BoxCoderOp should not be null.");
25+
PADDLE_ENFORCE(ctx->HasInput("PriorBoxVar"),
26+
"Input(PriorBoxVar) of BoxCoderOp should not be null.");
27+
PADDLE_ENFORCE(ctx->HasInput("TargetBox"),
28+
"Input(TargetBox) of BoxCoderOp should not be null.");
29+
PADDLE_ENFORCE(ctx->HasOutput("OutputBox"),
30+
"Output(OutputBox) of BoxCoderOp should not be null.");
31+
32+
auto prior_box_dims = ctx->GetInputDim("PriorBox");
33+
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
34+
auto target_box_dims = ctx->GetInputDim("TargetBox");
35+
36+
PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2,
37+
"The rank of Input of PriorBoxVar must be 2");
38+
PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]");
39+
PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims);
40+
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
41+
"The rank of Input of TargetBox must be 2");
42+
PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
43+
"The shape of TargetBox is [M, 4]");
44+
45+
GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
46+
47+
ctx->SetOutputDim(
48+
"OutputBox",
49+
framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4}));
50+
ctx->ShareLoD("TargetBox", /*->*/ "OutputBox");
51+
}
52+
};
53+
54+
class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
55+
public:
56+
BoxCoderOpMaker(OpProto *proto, OpAttrChecker *op_checker)
57+
: OpProtoAndCheckerMaker(proto, op_checker) {
58+
AddInput(
59+
"PriorBox",
60+
"(Tensor, default Tensor<float>) "
61+
"Box list PriorBox is a 2-D Tensor with shape [M, 4] holds M boxes, "
62+
"each box is represented as [xmin, ymin, xmax, ymax], "
63+
"[xmin, ymin] is the left top coordinate of the anchor box, "
64+
"if the input is image feature map, they are close to the origin "
65+
"of the coordinate system. [xmax, ymax] is the right bottom "
66+
"coordinate of the anchor box.");
67+
AddInput("PriorBoxVar",
68+
"(Tensor, default Tensor<float>) "
69+
"PriorBoxVar is a 2-D Tensor with shape [M, 4] holds M group "
70+
"of variance.");
71+
AddInput(
72+
"TargetBox",
73+
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
74+
"[N, 4], each box is represented as [xmin, ymin, xmax, ymax], "
75+
"[xmin, ymin] is the left top coordinate of the box if the input "
76+
"is image feature map, they are close to the origin of the coordinate "
77+
"system. [xmax, ymax] is the right bottom coordinate of the box. "
78+
"This tensor can contain LoD information to represent a batch "
79+
"of inputs. One instance of this batch can contain different "
80+
"numbers of entities.");
81+
AddAttr<std::string>("code_type",
82+
"(string, default encode_center_size) "
83+
"the code type used with the target box")
84+
.SetDefault("encode_center_size")
85+
.InEnum({"encode_center_size", "decode_center_size"});
86+
AddOutput(
87+
"OutputBox",
88+
"(LoDTensor or Tensor) "
89+
"(Tensor) The output of box_coder_op, a tensor with shape [N, M, 4] "
90+
"representing the result of N target boxes encoded/decoded with "
91+
"M Prior boxes and variances.");
92+
93+
AddComment(R"DOC(
94+
Bounding Box Coder Operator.
95+
Encode/Decode the target bounding box with the priorbox information.
96+
The Encoding schema described below:
97+
ox = (tx - px) / pw / pxv
98+
oy = (ty - py) / ph / pyv
99+
ow = log(abs(tw / pw)) / pwv
100+
oh = log(abs(th / ph)) / phv
101+
The Decoding schema described below:
102+
ox = (pw * pxv * tx * + px) - tw / 2
103+
oy = (ph * pyv * ty * + py) - th / 2
104+
ow = exp(pwv * tw) * pw + tw / 2
105+
oh = exp(phv * th) * ph + th / 2
106+
where tx, ty, tw, th denote the target box's center coordinates, width and
107+
height respectively. Similarly, px, py, pw, ph denote the priorbox's(anchor)
108+
center coordinates, width and height. pxv, pyv, pwv, phv denote the variance
109+
of the priorbox and ox, oy, ow, oh denote the encoded/decoded coordinates,
110+
width and height.
111+
)DOC");
112+
}
113+
};
114+
115+
} // namespace operators
116+
} // namespace paddle
117+
118+
namespace ops = paddle::operators;
119+
REGISTER_OP_WITHOUT_GRADIENT(box_coder, ops::BoxCoderOp, ops::BoxCoderOpMaker);
120+
REGISTER_OP_CPU_KERNEL(box_coder, ops::BoxCoderKernel<float>,
121+
ops::BoxCoderKernel<double>);

paddle/operators/box_coder_op.cu

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/operators/box_coder_op.h"
13+
#include "paddle/platform/cuda_helper.h"
14+
15+
namespace paddle {
16+
namespace operators {
17+
18+
template <typename T>
19+
__global__ void EncodeCenterSizeKernel(const T* prior_box_data,
20+
const T* prior_box_var_data,
21+
const T* target_box_data, const int row,
22+
const int col, const int len,
23+
T* output) {
24+
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
25+
if (idx < row * col) {
26+
const int row_idx = idx / col;
27+
const int col_idx = idx % col;
28+
T prior_box_width =
29+
prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len];
30+
T prior_box_height =
31+
prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1];
32+
T prior_box_center_x =
33+
(prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2;
34+
T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
35+
prior_box_data[col_idx * len + 1]) /
36+
2;
37+
38+
T target_box_center_x =
39+
(target_box_data[row_idx * len + 2] + target_box_data[row_idx * len]) /
40+
2;
41+
T target_box_center_y = (target_box_data[row_idx * len + 3] +
42+
target_box_data[row_idx * len + 1]) /
43+
2;
44+
T target_box_width =
45+
target_box_data[row_idx * len + 2] - target_box_data[row_idx * len];
46+
T target_box_height =
47+
target_box_data[row_idx * len + 3] - target_box_data[row_idx * len + 1];
48+
49+
output[idx * len] = (target_box_center_x - prior_box_center_x) /
50+
prior_box_width / prior_box_var_data[col_idx * len];
51+
output[idx * len + 1] = (target_box_center_y - prior_box_center_y) /
52+
prior_box_height /
53+
prior_box_var_data[col_idx * len + 1];
54+
output[idx * len + 2] = log(fabs(target_box_width / prior_box_width)) /
55+
prior_box_var_data[col_idx * len + 2];
56+
output[idx * len + 3] = log(fabs(target_box_height / prior_box_height)) /
57+
prior_box_var_data[col_idx * len + 3];
58+
}
59+
}
60+
61+
template <typename T>
62+
__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
63+
const T* prior_box_var_data,
64+
const T* target_box_data, const int row,
65+
const int col, const int len,
66+
T* output) {
67+
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
68+
if (idx < row * col) {
69+
const int row_idx = idx / col;
70+
const int col_idx = idx % col;
71+
T prior_box_width =
72+
prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len];
73+
T prior_box_height =
74+
prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1];
75+
T prior_box_center_x =
76+
(prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2;
77+
T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
78+
prior_box_data[col_idx * len + 1]) /
79+
2;
80+
81+
T target_box_width = exp(prior_box_var_data[col_idx * len + 2] *
82+
target_box_data[row_idx * len + 2]) *
83+
prior_box_width;
84+
T target_box_height = exp(prior_box_var_data[col_idx * len + 3] *
85+
target_box_data[row_idx * len + 3]) *
86+
prior_box_height;
87+
T target_box_center_x = prior_box_var_data[col_idx * len] *
88+
target_box_data[row_idx * len] *
89+
prior_box_width +
90+
prior_box_center_x;
91+
T target_box_center_y = prior_box_var_data[col_idx * len + 1] *
92+
target_box_data[row_idx * len + 1] *
93+
prior_box_height +
94+
prior_box_center_y;
95+
96+
output[idx * len] = target_box_center_x - target_box_width / 2;
97+
output[idx * len + 1] = target_box_center_y - target_box_height / 2;
98+
output[idx * len + 2] = target_box_center_x + target_box_width / 2;
99+
output[idx * len + 3] = target_box_center_y + target_box_height / 2;
100+
}
101+
}
102+
103+
template <typename T>
104+
class BoxCoderCUDAKernel : public framework::OpKernel<T> {
105+
public:
106+
void Compute(const framework::ExecutionContext& context) const override {
107+
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
108+
"This kernel only runs on GPU device.");
109+
auto* prior_box = context.Input<framework::Tensor>("PriorBox");
110+
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
111+
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
112+
auto* output_box = context.Output<framework::Tensor>("OutputBox");
113+
114+
if (target_box->lod().size()) {
115+
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
116+
"Only support 1 level of LoD.");
117+
}
118+
auto row = target_box->dims()[0];
119+
auto col = prior_box->dims()[0];
120+
auto len = prior_box->dims()[1];
121+
int block = 512;
122+
int grid = (row * col + block - 1) / block;
123+
auto& device_ctx = context.cuda_device_context();
124+
125+
const T* prior_box_data = prior_box->data<T>();
126+
const T* prior_box_var_data = prior_box_var->data<T>();
127+
const T* target_box_data = target_box->data<T>();
128+
129+
output_box->mutable_data<T>({row, col, len}, context.GetPlace());
130+
T* output = output_box->data<T>();
131+
132+
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
133+
if (code_type == BoxCodeType::kEncodeCenterSize) {
134+
EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
135+
prior_box_data, prior_box_var_data, target_box_data, row, col, len,
136+
output);
137+
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
138+
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
139+
prior_box_data, prior_box_var_data, target_box_data, row, col, len,
140+
output);
141+
}
142+
}
143+
};
144+
145+
} // namespace operators
146+
} // namespace paddle
147+
148+
namespace ops = paddle::operators;
149+
REGISTER_OP_CUDA_KERNEL(box_coder, ops::BoxCoderCUDAKernel<float>,
150+
ops::BoxCoderCUDAKernel<double>);

0 commit comments

Comments
 (0)