-
Notifications
You must be signed in to change notification settings - Fork 5.9k
add box coder op #7922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
add box coder op #7922
Changes from 6 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
72eccb2
add box coder op
bc6c4db
Update box_coder_op.cc
Noplz 58bfaea
update according to the code review
02d2b7b
update according to the code review
c3e89f3
update accoding to the code review
e14272b
update accoding to the code review
251c2fd
Update according to the code review
7d8d9db
Update according to the code review
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/operators/box_coder_op.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| class BoxCoderOp : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
| protected: | ||
| void InferShape(framework::InferShapeContext *ctx) const override { | ||
| PADDLE_ENFORCE(ctx->HasInput("PriorBox"), | ||
| "Input(PriorBox) of BoxCoderOp should not be null."); | ||
| PADDLE_ENFORCE(ctx->HasInput("PriorBoxVar"), | ||
| "Input(PriorBoxVar) of BoxCoderOp should not be null."); | ||
| PADDLE_ENFORCE(ctx->HasInput("PriorBox"), | ||
| "Input(TargetBox) of BoxCoderOp should not be null."); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also need to check output var:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
|
||
| auto prior_box_dims = ctx->GetInputDim("PriorBox"); | ||
| auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar"); | ||
| auto target_box_dims = ctx->GetInputDim("TargetBox"); | ||
|
|
||
| PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2, | ||
| "The rank of Input of PriorBoxVar must be 2"); | ||
| PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]"); | ||
| PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims); | ||
| PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, | ||
| "The rank of Input of TargetBox must be 2"); | ||
| PADDLE_ENFORCE_EQ(target_box_dims[1], 4, | ||
| "The shape of TargetBox is [M, 4]"); | ||
|
|
||
| GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type")); | ||
|
|
||
| ctx->SetOutputDim( | ||
| "OutputBox", | ||
| framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); | ||
| ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); | ||
| } | ||
| }; | ||
|
|
||
| class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { | ||
| public: | ||
| BoxCoderOpMaker(OpProto *proto, OpAttrChecker *op_checker) | ||
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInput( | ||
| "PriorBox", | ||
| "(Tensor, default Tensor<float>) " | ||
| "Box list PriorBox is a 2-D Tensor with shape [M, 4] holds M boxes, " | ||
| "each box is represented as [xmin, ymin, xmax, ymax], " | ||
| "[xmin, ymin] is the left top coordinate of the anchor box, " | ||
| "if the input is image feature map, they are close to the origin " | ||
| "of the coordinate system. [xmax, ymax] is the right bottom " | ||
| "coordinate of the anchor box."); | ||
| AddInput("PriorBoxVar", | ||
| "(Tensor, default Tensor<float>) " | ||
| "PriorBoxVar is a 2-D Tensor with shape [M, 4] holds M group " | ||
| "of variance."); | ||
| AddInput( | ||
| "TargetBox", | ||
| "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape " | ||
| "[N, 4], each box is represented as [xmin, ymin, xmax, ymax], " | ||
| "[xmin, ymin] is the left top coordinate of the box if the input " | ||
| "is image feature map, they are close to the origin of the coordinate " | ||
| "system. [xmax, ymax] is the right bottom coordinate of the box. " | ||
| "This tensor can contain LoD information to represent a batch " | ||
| "of inputs. One instance of this batch can contain different " | ||
| "numbers of entities."); | ||
| AddAttr<std::string>("code_type", | ||
| "(string, default encode_center_size) " | ||
| "the code type used with the target box") | ||
| .SetDefault("encode_center_size") | ||
| .InEnum({"encode_center_size", "decode_center_size"}); | ||
| AddOutput( | ||
| "OutputBox", | ||
| "(LoDTensor or Tensor) " | ||
| "(Tensor) The output of box_coder_op, a tensor with shape [N, M, 4] " | ||
| "representing the result of N target boxes encoded/decoded with " | ||
| "M Prior boxes and variances."); | ||
|
|
||
| AddComment(R"DOC( | ||
| Bounding Box Coder Operator. | ||
| Encode/Decode the target bounding box with the priorbox information. | ||
| The Encoding schema described below: | ||
| ox = (tx - px) / pw / pxv | ||
| oy = (ty - py) / ph / pyv | ||
| ow = log(abs(tw / pw)) / pwv | ||
| oh = log(abs(th / ph)) / phv | ||
| The Decoding schema described below: | ||
| ox = (pw * pxv * tx * + px) - tw / 2 | ||
| oy = (ph * pyv * ty * + py) - th / 2 | ||
| ow = exp(pwv * tw) * pw + tw / 2 | ||
| oh = exp(phv * th) * ph + th / 2 | ||
| where tx, ty, tw, th denote the target box's center coordinates, width and | ||
| height respectively. Similarly, px, py, pw, ph denote the priorbox's(anchor) | ||
| center coordinates, width and height. pxv, pyv, pwv, phv denote the variance | ||
| of the priorbox and ox, oy, ow, oh denote the encoded/decoded coordinates, | ||
| width and height. | ||
| )DOC"); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| namespace ops = paddle::operators; | ||
| REGISTER_OP_WITHOUT_GRADIENT(box_coder, ops::BoxCoderOp, ops::BoxCoderOpMaker); | ||
| REGISTER_OP_CPU_KERNEL(box_coder, ops::BoxCoderKernel<float>, | ||
| ops::BoxCoderKernel<double>); | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,150 @@ | ||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/operators/box_coder_op.h" | ||
| #include "paddle/platform/cuda_helper.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| template <typename T> | ||
| __global__ void EncodeCenterSizeKernel(const T* prior_box_data, | ||
| const T* prior_box_var_data, | ||
| const T* target_box_data, const int row, | ||
| const int col, const int len, | ||
| T* output) { | ||
| const int idx = threadIdx.x + blockIdx.x * blockDim.x; | ||
| if (idx < row * col) { | ||
| const int row_idx = idx / col; | ||
| const int col_idx = idx % col; | ||
| T prior_box_width = | ||
| prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len]; | ||
| T prior_box_height = | ||
| prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1]; | ||
| T prior_box_center_x = | ||
| (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; | ||
| T prior_box_center_y = (prior_box_data[col_idx * len + 3] + | ||
| prior_box_data[col_idx * len + 1]) / | ||
| 2; | ||
|
|
||
| T target_box_center_x = | ||
| (target_box_data[row_idx * len + 2] + target_box_data[row_idx * len]) / | ||
| 2; | ||
| T target_box_center_y = (target_box_data[row_idx * len + 3] + | ||
| target_box_data[row_idx * len + 1]) / | ||
| 2; | ||
| T target_box_width = | ||
| target_box_data[row_idx * len + 2] - target_box_data[row_idx * len]; | ||
| T target_box_height = | ||
| target_box_data[row_idx * len + 3] - target_box_data[row_idx * len + 1]; | ||
|
|
||
| output[idx * len] = (target_box_center_x - prior_box_center_x) / | ||
| prior_box_width / prior_box_var_data[col_idx * len]; | ||
| output[idx * len + 1] = (target_box_center_y - prior_box_center_y) / | ||
| prior_box_height / | ||
| prior_box_var_data[col_idx * len + 1]; | ||
| output[idx * len + 2] = log(fabs(target_box_width / prior_box_width)) / | ||
| prior_box_var_data[col_idx * len + 2]; | ||
| output[idx * len + 3] = log(fabs(target_box_height / prior_box_height)) / | ||
| prior_box_var_data[col_idx * len + 3]; | ||
| } | ||
| } | ||
|
|
||
| template <typename T> | ||
| __global__ void DecodeCenterSizeKernel(const T* prior_box_data, | ||
| const T* prior_box_var_data, | ||
| const T* target_box_data, const int row, | ||
| const int col, const int len, | ||
| T* output) { | ||
| const int idx = threadIdx.x + blockIdx.x * blockDim.x; | ||
| if (idx < row * col) { | ||
| const int row_idx = idx / col; | ||
| const int col_idx = idx % col; | ||
| T prior_box_width = | ||
| prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len]; | ||
| T prior_box_height = | ||
| prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1]; | ||
| T prior_box_center_x = | ||
| (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; | ||
| T prior_box_center_y = (prior_box_data[col_idx * len + 3] + | ||
| prior_box_data[col_idx * len + 1]) / | ||
| 2; | ||
|
|
||
| T target_box_width = exp(prior_box_var_data[col_idx * len + 2] * | ||
| target_box_data[row_idx * len + 2]) * | ||
| prior_box_width; | ||
| T target_box_height = exp(prior_box_var_data[col_idx * len + 3] * | ||
| target_box_data[row_idx * len + 3]) * | ||
| prior_box_height; | ||
| T target_box_center_x = prior_box_var_data[col_idx * len] * | ||
| target_box_data[row_idx * len] * | ||
| prior_box_width + | ||
| prior_box_center_x; | ||
| T target_box_center_y = prior_box_var_data[col_idx * len + 1] * | ||
| target_box_data[row_idx * len + 1] * | ||
| prior_box_height + | ||
| prior_box_center_y; | ||
|
|
||
| output[idx * len] = target_box_center_x - target_box_width / 2; | ||
| output[idx * len + 1] = target_box_center_y - target_box_height / 2; | ||
| output[idx * len + 2] = target_box_center_x + target_box_width / 2; | ||
| output[idx * len + 3] = target_box_center_y + target_box_height / 2; | ||
| } | ||
| } | ||
|
|
||
| template <typename T> | ||
| class BoxCoderCUDAKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& context) const override { | ||
| PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), | ||
| "This kernel only runs on GPU device."); | ||
| auto* prior_box = context.Input<framework::Tensor>("PriorBox"); | ||
| auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar"); | ||
| auto* target_box = context.Input<framework::LoDTensor>("TargetBox"); | ||
| auto* output_box = context.Output<Tensor>("OutputBox"); | ||
|
|
||
| if (target_box->lod().size()) { | ||
| PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, | ||
| "Only support 1 level of LoD."); | ||
| } | ||
| auto row = target_box->dims()[0]; | ||
| auto col = prior_box->dims()[0]; | ||
| auto len = prior_box->dims()[1]; | ||
| int block = 512; | ||
| int grid = (row * col + block - 1) / block; | ||
| auto& device_ctx = context.cuda_device_context(); | ||
|
|
||
| const T* prior_box_data = prior_box->data<T>(); | ||
| const T* prior_box_var_data = prior_box_var->data<T>(); | ||
| const T* target_box_data = target_box->data<T>(); | ||
|
|
||
| output_box->mutable_data<T>({row, col, len}, context.GetPlace()); | ||
| T* output = output_box->data<T>(); | ||
|
|
||
| auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type")); | ||
| if (code_type == BoxCodeType::kEncodeCenterSize) { | ||
| EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>( | ||
| prior_box_data, prior_box_var_data, target_box_data, row, col, len, | ||
| output); | ||
| } else if (code_type == BoxCodeType::kDecodeCenterSize) { | ||
| DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>( | ||
| prior_box_data, prior_box_var_data, target_box_data, row, col, len, | ||
| output); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| namespace ops = paddle::operators; | ||
| REGISTER_OP_CUDA_KERNEL(box_coder, ops::BoxCoderCUDAKernel<float>, | ||
| ops::BoxCoderCUDAKernel<double>); |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done