Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions paddle/operators/box_coder_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2016 -> 2018.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/box_coder_op.h"

namespace paddle {
namespace operators {

class BoxCoderOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
"Input(PriorBox) of BoxCoderOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PriorBoxVar"),
"Input(PriorBoxVar) of BoxCoderOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PADDLE_ENFORCE(ctx->HasInput("TargetBox"),

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

"Input(TargetBox) of BoxCoderOp should not be null.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also need to check output var:

PADDLE_ENFORCE(ctx->HasOutput("OutputBox"), ...);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


auto prior_box_dims = ctx->GetInputDim("PriorBox");
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
auto target_box_dims = ctx->GetInputDim("TargetBox");

PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2UL,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type of prior_box_dims.size() is int, comparing with 2UL will generate warnings. Just 2 is ok. The same as below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"The rank of Input of PriorBox must be 2");
PADDLE_ENFORCE_EQ(prior_box_dims[1], 4UL,
"The shape of PriorBox is [N, 4]");
PADDLE_ENFORCE_EQ(prior_box_var_dims.size(), 2UL,
"The rank of Input of PriorBoxVar must be 2");
PADDLE_ENFORCE_EQ(prior_box_var_dims[1], 4UL,
"The shape of PriorBoxVar is [N, 4]");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change line 38 - line 41 to:

PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

PADDLE_ENFORCE_EQ(target_box_dims.size(), 2UL,
"The rank of Input of TargetBox must be 2");
PADDLE_ENFORCE_EQ(target_box_dims[1], 4UL,
"The shape of TargetBox is [M, 4]");

GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));

ctx->SetOutputDim("OutputBox", framework::make_ddim({target_box_dims[0],
target_box_dims[1]}));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The shape for OutputBox is not correct, it should be: {target_box_dims[0], prior_box_dims[0], 4}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also need to share LoD between TargetBox and TargetBox:

ctx->ShareLoD("TargetBox", /*->*/ "OutputBox");

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
};

class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BoxCoderOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"PriorBox",
"(Tensor, default Tensor<float>) "
"Box list PriorBox is a 2-D Tensor with shape [M, 4] holds N boxes, "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

N boxes -> M boxes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"each box is represented as [xmin, ymin, xmax, ymax], "
"[xmin, ymin] is the left top coordinate of the anchor box, "
"if the input is image feature map, they are close to the origin "
"of the coordinate system. [xmax, ymax] is the right bottom "
"coordinate of the anchor box.");
AddInput("PriorBoxVar",
"(Tensor, default Tensor<float>) "
"PriorBoxVar is a 2-D Tensor with shape [M, 4] holds N group "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

M, 4] holds N group -> M, 4] holds M group

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"of variance.");
AddInput(
"TargetBox",
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
"[N, 4], each box is represented as [xmin, ymin, xmax, ymax], "
"[xmin, ymin] is the left top coordinate of the box if the input "
"is image feature map, they are close to the origin of the coordinate "
"system. [xmax, ymax] is the right bottom coordinate of the box. "
"This tensor can contain LoD information to represent a batch "
"of inputs. One instance of this batch can contain different "
"numbers of entities.");
AddAttr<std::string>("code_type",
"(string, default encode_center_size) "
"the code type used with the target box")
.SetDefault("encode_center_size")
.InEnum({"encode_center_size", "decode_center_size"});
AddOutput(
"OutputBox",
"(Tensor, default Tensor<float>)"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(LoDTensor or Tenosr), the type is the same with TargetBox

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"(Tensor) The output of box_coder_op, a tensor with shape [N, M, 4] "
"representing the result of N target boxes encoded/decoded with "
"M Prior boxes and variances.");

AddComment(R"DOC(
Bounding Box Coder Operator.
Encode/Decode the priorbox information with the target bounding box.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need more detailed description for encode and decode.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(box_coder, ops::BoxCoderOp, ops::BoxCoderOpMaker);
REGISTER_OP_CPU_KERNEL(box_coder, ops::BoxCoderKernel<float>,
ops::BoxCoderKernel<double>);
145 changes: 145 additions & 0 deletions paddle/operators/box_coder_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/box_coder_op.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems PADDLE_CUDA_NUM_THREADS is not used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed


template <typename T>
__global__ void EncodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data,
const T* target_box_data, int row,
int col, T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) {
const int row_idx = idx / col;
const int col_idx = idx % col;
T prior_box_width =
prior_box_data[col_idx * 4 + 2] - prior_box_data[col_idx * 4];
T prior_box_height =
prior_box_data[col_idx * 4 + 3] - prior_box_data[col_idx * 4 + 1];
T prior_box_center_x =
(prior_box_data[col_idx * 4 + 2] + prior_box_data[col_idx * 4]) / 2;
T prior_box_center_y =
(prior_box_data[col_idx * 4 + 3] + prior_box_data[col_idx * 4 + 1]) / 2;

T target_box_center_x =
(target_box_data[row_idx * 4 + 2] + target_box_data[row_idx * 4]) / 2;
T target_box_center_y =
(target_box_data[row_idx * 4 + 3] + target_box_data[row_idx * 4 + 1]) /
2;
T target_box_width =
target_box_data[row_idx * 4 + 2] - target_box_data[row_idx * 4];
T target_box_height =
target_box_data[row_idx * 4 + 3] - target_box_data[row_idx * 4 + 1];

output[idx * 4] = (target_box_center_x - prior_box_center_x) /
prior_box_width / prior_box_var_data[col_idx * 4];
output[idx * 4 + 1] = (target_box_center_y - prior_box_center_y) /
prior_box_height /
prior_box_var_data[col_idx * 4 + 1];
output[idx * 4 + 2] = log(fabs(target_box_width / prior_box_width)) /
prior_box_var_data[col_idx * 4 + 2];
output[idx * 4 + 3] = log(fabs(target_box_height / prior_box_height)) /
prior_box_var_data[col_idx * 4 + 3];
}
}

template <typename T>
__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data,
const T* target_box_data, int row,
int col, T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) {
const int row_idx = idx / col;
const int col_idx = idx % col;
T prior_box_width =
prior_box_data[col_idx * 4 + 2] - prior_box_data[col_idx * 4];
T prior_box_height =
prior_box_data[col_idx * 4 + 3] - prior_box_data[col_idx * 4 + 1];
T prior_box_center_x =
(prior_box_data[col_idx * 4 + 2] + prior_box_data[col_idx * 4]) / 2;
T prior_box_center_y =
(prior_box_data[col_idx * 4 + 3] + prior_box_data[col_idx * 4 + 1]) / 2;

T target_box_width = exp(prior_box_var_data[col_idx * 4 + 2] *
target_box_data[row_idx * 4 + 2]) *
prior_box_width;
T target_box_height = exp(prior_box_var_data[col_idx * 4 + 3] *
target_box_data[row_idx * 4 + 3]) *
prior_box_height;
T target_box_center_x = prior_box_var_data[col_idx * 4] *
target_box_data[row_idx * 4] * prior_box_width +
prior_box_center_x;
T target_box_center_y = prior_box_var_data[col_idx * 4 + 1] *
target_box_data[row_idx * 4 + 1] *
prior_box_height +
prior_box_center_y;

output[idx * 4] = target_box_center_x - target_box_width / 2;
output[idx * 4 + 1] = target_box_center_y - target_box_height / 2;
output[idx * 4 + 2] = target_box_center_x + target_box_width / 2;
output[idx * 4 + 3] = target_box_center_y + target_box_height / 2;
}
}

template <typename T>
class BoxCoderCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"This kernel only runs on GPU device.");
auto* prior_box = context.Input<framework::Tensor>("PriorBox");
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<Tensor>("OutputBox");

if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
"Only support 1 level of LoD.");
}
auto row = target_box->dims()[0];
auto col = prior_box->dims()[0];
int block = 512;
int grid = (row * col + block - 1) / block;
auto& device_ctx = context.cuda_device_context();

const T* prior_box_data = prior_box->data<T>();
const T* prior_box_var_data = prior_box_var->data<T>();
const T* target_box_data = target_box->data<T>();

output_box->mutable_data<T>({row, col, 4}, context.GetPlace());
T* output = output_box->data<T>();

auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col,
output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col,
output);
}
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(box_coder, ops::BoxCoderCUDAKernel<float>,
ops::BoxCoderCUDAKernel<double>);
Loading