diff --git a/paddle/operators/math/detection_util.h b/paddle/operators/math/detection_util.h new file mode 100755 index 00000000000000..b9600ecff80cb3 --- /dev/null +++ b/paddle/operators/math/detection_util.h @@ -0,0 +1,332 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once +#include "paddle/framework/selected_rows.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/strided_memcpy.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct BBox { + BBox(T x_min, T y_min, T x_max, T y_max) + : x_min(x_min), + y_min(y_min), + x_max(x_max), + y_max(y_max), + is_difficult(false) {} + + BBox() {} + + T get_width() const { return x_max - x_min; } + + T get_height() const { return y_max - y_min; } + + T get_center_x() const { return (x_min + x_max) / 2; } + + T get_center_y() const { return (y_min + y_max) / 2; } + + T get_area() const { return get_width() * get_height(); } + + // coordinate of bounding box + T x_min; + T y_min; + T x_max; + T y_max; + // whether difficult object (e.g. object with heavy occlusion is difficult) + bool is_difficult; +}; + +template +void GetBBoxFromDetectData(const T* detect_data, const size_t num_bboxes, + std::vector& labels, std::vector& scores, + std::vector>& bboxes) { + size_t out_offset = bboxes.size(); + labels.resize(out_offset + num_bboxes); + scores.resize(out_offset + num_bboxes); + bboxes.resize(out_offset + num_bboxes); + for (size_t i = 0; i < num_bboxes; ++i) { + labels[out_offset + i] = *(detect_data + i * 7 + 1); + scores[out_offset + i] = *(detect_data + i * 7 + 2); + BBox bbox; + bbox.x_min = *(detect_data + i * 7 + 3); + bbox.y_min = *(detect_data + i * 7 + 4); + bbox.x_max = *(detect_data + i * 7 + 5); + bbox.y_max = *(detect_data + i * 7 + 6); + bboxes[out_offset + i] = bbox; + }; +} + +template +void GetBBoxFromLabelData(const T* label_data, const size_t num_bboxes, + std::vector>& bboxes) { + size_t out_offset = bboxes.size(); + bboxes.resize(bboxes.size() + num_bboxes); + for (size_t i = 0; i < num_bboxes; ++i) { + BBox bbox; + bbox.x_min = *(label_data + i * 6 + 1); + bbox.y_min = *(label_data + i * 6 + 2); + bbox.x_max = *(label_data + i * 6 + 3); + bbox.y_max = *(label_data + i * 6 + 4); + T is_difficult = *(label_data + i * 6 + 5); + if (std::abs(is_difficult - 0.0) < 1e-6) + bbox.is_difficult = false; + else + bbox.is_difficult = true; + bboxes[out_offset + i] = bbox; + } +} + +template +void GetBBoxFromPriorData(const T* prior_data, const size_t num_bboxes, + std::vector>& bboxes) { + size_t out_offset = bboxes.size(); + bboxes.resize(bboxes.size() + num_bboxes); + for (size_t i = 0; i < num_bboxes; ++i) { + BBox bbox; + bbox.x_min = *(prior_data + i * 8); + bbox.y_min = *(prior_data + i * 8 + 1); + bbox.x_max = *(prior_data + i * 8 + 2); + bbox.y_max = *(prior_data + i * 8 + 3); + bboxes[out_offset + i] = bbox; + } +} + +template +void GetBBoxVarFromPriorData(const T* prior_data, const size_t num, + std::vector>& var_vec) { + size_t out_offset = var_vec.size(); + var_vec.resize(var_vec.size() + num); + for (size_t i = 0; i < num; ++i) { + std::vector var; + var.push_back(*(prior_data + i * 8 + 4)); + var.push_back(*(prior_data + i * 8 + 5)); + var.push_back(*(prior_data + i * 8 + 6)); + var.push_back(*(prior_data + i * 8 + 7)); + var_vec[out_offset + i] = var; + } +} + +template +void EncodeBBoxWithVar(const BBox& prior_bbox, + const std::vector& prior_bbox_var, + const BBox& gt_bbox, std::vector& out_vec) { + T prior_bbox_width = prior_bbox.get_width(); + T prior_bbox_height = prior_bbox.get_height(); + T prior_bbox_center_x = prior_bbox.get_center_x(); + T prior_bbox_center_y = prior_bbox.get_center_y(); + + T gt_bbox_width = gt_bbox.get_width(); + T gt_bbox_height = gt_bbox.get_height(); + T gt_bbox_center_x = gt_bbox.get_center_x(); + T gt_bbox_center_y = gt_bbox.get_center_y(); + + out_vec.clear(); + out_vec.push_back((gt_bbox_center_x - prior_bbox_center_x) / + prior_bbox_width / prior_bbox_var[0]); + out_vec.push_back((gt_bbox_center_y - prior_bbox_center_y) / + prior_bbox_height / prior_bbox_var[1]); + out_vec.push_back(std::log(std::fabs(gt_bbox_width / prior_bbox_width)) / + prior_bbox_var[2]); + out_vec.push_back(std::log(std::fabs(gt_bbox_height / prior_bbox_height)) / + prior_bbox_var[3]); +} + +template +inline float JaccardOverlap(const BBox& bbox1, const BBox& bbox2) { + if (bbox2.x_min > bbox1.x_max || bbox2.x_max < bbox1.x_min || + bbox2.y_min > bbox1.y_max || bbox2.y_max < bbox1.y_min) { + return 0.0; + } else { + float inter_x_min = std::max(bbox1.x_min, bbox2.x_min); + float inter_y_min = std::max(bbox1.y_min, bbox2.y_min); + float inter_x_max = std::min(bbox1.x_max, bbox2.x_max); + float inter_y_max = std::min(bbox1.y_max, bbox2.y_max); + + float inter_width = inter_x_max - inter_x_min; + float inter_height = inter_y_max - inter_y_min; + float inter_area = inter_width * inter_height; + + float bbox_area1 = bbox1.get_area(); + float bbox_area2 = bbox2.get_area(); + + return inter_area / (bbox_area1 + bbox_area2 - inter_area); + } +} + +template +void MatchBBox(const std::vector>& prior_bboxes, + const std::vector>& gt_bboxes, float overlap_threshold, + std::vector& match_indices, + std::vector& match_overlaps) { + std::map> overlaps; + size_t num_priors = prior_bboxes.size(); + size_t num_gts = gt_bboxes.size(); + + match_indices.clear(); + match_indices.resize(num_priors, -1); + match_overlaps.clear(); + match_overlaps.resize(num_priors, 0.0); + + // Store the positive overlap between predictions and ground truth + for (size_t i = 0; i < num_priors; ++i) { + for (size_t j = 0; j < num_gts; ++j) { + float overlap = JaccardOverlap(prior_bboxes[i], gt_bboxes[j]); + if (overlap > 1e-6) { + match_overlaps[i] = std::max(match_overlaps[i], overlap); + overlaps[i][j] = overlap; + } + } + } + // Bipartite matching + std::vector gt_pool; + for (size_t i = 0; i < num_gts; ++i) { + gt_pool.push_back(i); + } + while (gt_pool.size() > 0) { + // Find the most overlapped gt and corresponding predictions + int max_prior_idx = -1; + int max_gt_idx = -1; + float max_overlap = -1.0; + for (auto it = overlaps.begin(); it != overlaps.end(); ++it) { + size_t i = it->first; + if (match_indices[i] != -1) { + // The prediction already has matched ground truth or is ignored + continue; + } + for (size_t p = 0; p < gt_pool.size(); ++p) { + int j = gt_pool[p]; + if (it->second.find(j) == it->second.end()) { + // No overlap between the i-th prediction and j-th ground truth + continue; + } + // Find the maximum overlapped pair + if (it->second[j] > max_overlap) { + max_prior_idx = (int)i; + max_gt_idx = (int)j; + max_overlap = it->second[j]; + } + } + } + if (max_prior_idx == -1) { + break; + } else { + match_indices[max_prior_idx] = max_gt_idx; + match_overlaps[max_prior_idx] = max_overlap; + gt_pool.erase(std::find(gt_pool.begin(), gt_pool.end(), max_gt_idx)); + } + } + + // Get most overlaped for the rest prediction bboxes + for (auto it = overlaps.begin(); it != overlaps.end(); ++it) { + size_t i = it->first; + if (match_indices[i] != -1) { + // The prediction already has matched ground truth or is ignored + continue; + } + int max_gt_idx = -1; + float max_overlap = -1; + for (size_t j = 0; j < num_gts; ++j) { + if (it->second.find(j) == it->second.end()) { + // No overlap between the i-th prediction and j-th ground truth + continue; + } + // Find the maximum overlapped pair + float overlap = it->second[j]; + if (overlap > max_overlap && overlap >= overlap_threshold) { + max_gt_idx = j; + max_overlap = overlap; + } + } + if (max_gt_idx != -1) { + match_indices[i] = max_gt_idx; + match_overlaps[i] = max_overlap; + } + } +} + +template +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +template +int TransposeFromNCHWToNHWC(const platform::Place& dst_place, + const DeviceContext& ctx, + const framework::Tensor& src, + framework::Tensor& dst, int dst_total_size, + int dst_offset) { + int batch_size = src.dims()[0]; + std::vector shape_vec( + {src.dims()[0], src.dims()[2], src.dims()[3], src.dims()[1]}); + auto shape = framework::make_ddim(shape_vec); + framework::Tensor src_transpose; + src_transpose.mutable_data(shape, dst_place); + std::vector shape_axis({0, 2, 3, 1}); + math::Transpose trans4; + trans4(ctx, src, &src_transpose, shape_axis); + + auto src_stride = framework::stride(src.dims()); + + for (int i = 0; i < batch_size; ++i) { + int out_offset = i * (dst_total_size / batch_size) + dst_offset; + framework::Tensor src_i = src_transpose.Slice(i, i + 1); + + src_i.Resize(framework::make_ddim({1, src_i.numel()})); + + StridedMemcpy(ctx, src_i.data(), framework::stride(src_i.dims()), + src_i.dims(), framework::stride(dst.dims()), + dst.data() + out_offset); + } + + return src_stride[0]; +} + +template +int TransposeFromNHWCToNCHW(const platform::Place& dst_place, + const DeviceContext& ctx, + const framework::Tensor& src, int src_total_size, + int src_offset, framework::Tensor& dst) { + int batch_size = dst.dims()[0]; + + framework::Tensor dst_transpose; + std::vector shape_vec( + {dst.dims()[0], dst.dims()[3], dst.dims()[1], dst.dims()[2]}); + dst_transpose.mutable_data(framework::make_ddim(shape_vec), dst_place); + auto dst_stride = framework::stride(dst.dims()); + + for (int i = 0; i < batch_size; ++i) { + int in_offset = i * (src_total_size / batch_size) + src_offset; + framework::Tensor dst_i = dst_transpose.Slice(i, i + 1); + + dst_i.Resize(framework::make_ddim({1, dst_stride[0]})); + StridedMemcpy(ctx, src.data() + in_offset, + framework::stride(src.dims()), src.dims(), + framework::stride(dst_i.dims()), dst_i.data()); + } + + std::vector shape_axis({0, 3, 1, 2}); + math::Transpose trans4; + trans4(ctx, dst_transpose, &dst, shape_axis); + return dst_stride[0]; +} + +} // namespace math + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/multi_box_loss_op.cc b/paddle/operators/multi_box_loss_op.cc new file mode 100644 index 00000000000000..84aba2f0969dc6 --- /dev/null +++ b/paddle/operators/multi_box_loss_op.cc @@ -0,0 +1,136 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/multi_box_loss_op.h" + +namespace paddle { +namespace operators { + +class MultiBoxLossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInputs("Loc"), + "Inputs(Loc) of MultiBoxLossOp should not be null."); + PADDLE_ENFORCE(ctx->HasInputs("Conf"), + "Inputs(Conf) of MultiBoxLossOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), + "Input(Label) of MultiBoxLossOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Loss"), + "Output(Loss) of MultiBoxLossOp should not be null."); + + PADDLE_ENFORCE_EQ(ctx->Inputs("Loc").size(), ctx->Inputs("Conf").size(), + "The input number of Loc and Conf should be the same."); + + int input_num = ctx->Inputs("Loc").size(); + auto loc_dims = ctx->GetInputsDim("Loc"); + auto conf_dims = ctx->GetInputsDim("Conf"); + for (int i = 0; i < input_num; ++i) { + PADDLE_ENFORCE_EQ(loc_dims[i].size(), 4UL, + "The format of input(loc %d) tensor is NCHW.", i); + PADDLE_ENFORCE_EQ(conf_dims[i].size(), 4UL, + "The format of input(conf %d) tensor is NCHW.", i); + } + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Label").size(), 2UL, + "The dim size of input(label) tensor is 2."); + + auto loss_dims = framework::make_ddim({1}); + ctx->SetOutputDim("Loss", loss_dims); + auto couter_dims = framework::make_ddim({3}); + ctx->SetOutputDim("InterCounter", couter_dims); + } + + protected: + framework::OpKernelType GetActualKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Label")->type()), + ctx.device_context()); + } +}; + +class MultiBoxLossGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputsDim(framework::GradVarName("Loc"), ctx->GetInputsDim("Loc")); + ctx->SetOutputsDim(framework::GradVarName("Conf"), + ctx->GetInputsDim("Conf")); + } + + protected: + framework::OpKernelType GetActualKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Label")->type()), + ctx.device_context()); + } +}; + +class MultiBoxLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MultiBoxLossOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Loc", "The input predict locations.").AsDuplicable(); + AddInput("Conf", "The input priorbox confidence..").AsDuplicable(); + AddInput("PriorBox", "The input priorbox location."); + AddInput("Label", "The input label."); + AddOutput("Loss", "The output loss."); + AddOutput("InterCounter", "Internal use counter.").AsIntermediate(); + AddOutput("AllMatchIndices", "All match indices, internal use only.") + .AsIntermediate(); + AddOutput("AllNegIndices", "All negative indices, internal use only.") + .AsIntermediate(); + AddOutput("LocGTData", "Locations ground truth data, internal use only.") + .AsIntermediate(); + AddOutput("ConfGTData", "Confidence ground truth data, internal use only.") + .AsIntermediate(); + AddOutput("LocDiff", "Locations difference data, internal use only.") + .AsIntermediate(); + AddOutput("ConfProb", "Confidence possibility data, internal use only.") + .AsIntermediate(); + AddAttr("class_num", "The number of the classification.") + .SetDefault(0); + AddAttr("overlap_threshold", "The threshold of the overlap.") + .SetDefault(0.5); + AddAttr("neg_overlap", "The negative bbox overlap threshold.") + .SetDefault(0.5); + AddAttr("neg_pos_ratio", + "The ratio of the negative bbox to the positive bbox.") + .SetDefault(3); + AddAttr("background_label_id", "The background class index.") + .SetDefault(-1); + AddComment(R"DOC( +MultiBoxLoss operator +Compute the location loss and the confidence loss for ssd. +Please get more information from the following papers: +https://arxiv.org/abs/1512.02325. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_EX(multi_box_loss, ops::MultiBoxLossOp, ops::MultiBoxLossOpMaker, + multi_box_loss_grad, ops::MultiBoxLossGradOp, false); +REGISTER_OP_CPU_KERNEL( + multi_box_loss, + ops::MultiBoxLossOpKernel); +REGISTER_OP_CPU_KERNEL( + multi_box_loss_grad, + ops::MultiBoxLossGradOpKernel); diff --git a/paddle/operators/multi_box_loss_op.cu b/paddle/operators/multi_box_loss_op.cu new file mode 100755 index 00000000000000..3d202de5fe0b08 --- /dev/null +++ b/paddle/operators/multi_box_loss_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/multi_box_loss_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + multi_box_loss, + ops::MultiBoxLossOpKernel); +REGISTER_OP_CUDA_KERNEL( + multi_box_loss_grad, + ops::MultiBoxLossGradOpKernel); diff --git a/paddle/operators/multi_box_loss_op.h b/paddle/operators/multi_box_loss_op.h new file mode 100755 index 00000000000000..c0a7dc0d3d459f --- /dev/null +++ b/paddle/operators/multi_box_loss_op.h @@ -0,0 +1,685 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/cross_entropy.h" +#include "paddle/operators/math/detection_util.h" +#include "paddle/operators/math/softmax.h" + +namespace paddle { +namespace operators { + +template +T MultiBoxLossSmoothL1(const framework::ExecutionContext& ctx, + const framework::Tensor& output, + const framework::Tensor& label, int match_num, + T dest_scale) { + auto sample_num = output.dims()[0]; + auto dim = output.numel() / sample_num; + + PADDLE_ENFORCE_EQ(label.dims()[0], sample_num); + PADDLE_ENFORCE_EQ(label.numel(), output.numel()); + + const T* out_data = output.data(); + const T* label_data = label.data(); + + T cost = 0.0; + for (int i = 0; i < sample_num; ++i, out_data += dim, label_data += dim) { + T cost_i = 0.0; + for (int j = 0; j < dim; ++j) { + T abs = std::fabs(out_data[j] - label_data[j]); + cost_i *= dest_scale; + if (abs < 1.0) + cost_i += 0.5 * abs * abs; + else + cost_i += abs - 0.5; + } + cost += cost_i; + } + return cost / match_num; +} + +template +class MultiBoxLossOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins_loc = ctx.MultiInput("Loc"); + auto ins_conf = ctx.MultiInput("Conf"); + auto* in_priorbox = ctx.Input("PriorBox"); + auto* in_label = ctx.Input("Label"); + + auto* out_loss = ctx.Output("Loss"); + auto* out_inter_couter = ctx.Output("InterCounter"); + auto* out_all_match_indices = + ctx.Output("AllMatchIndices"); + + auto* out_all_neg_indices = ctx.Output("AllNegIndices"); + + auto* out_loc_gt = ctx.Output("LocGTData"); + auto* out_conf_gt = ctx.Output("ConfGTData"); + auto* out_loc_diff = ctx.Output("LocDiff"); + auto* out_conf_prob = ctx.Output("ConfProb"); + + int class_num = ctx.template Attr("class_num"); + float overlap_threshold = ctx.template Attr("overlap_threshold"); + float neg_pos_ratio = ctx.template Attr("neg_pos_ratio"); + float neg_overlap = ctx.template Attr("neg_overlap"); + int background_label_id = ctx.template Attr("background_label_id"); + + int input_num = ins_loc.size(); + int batch_size = ins_loc[0]->dims()[0]; + int prior_num = in_priorbox->numel() / 8; + + platform::CPUPlace cpu_place; + platform::CPUDeviceContext cpu_ctx(cpu_place); + + framework::Tensor loc_buffer_cpu; + framework::Tensor conf_buffer_cpu; + + int loc_size_sum = 0; + int conf_size_sum = 0; + for (int i = 0; i < input_num; ++i) { + loc_size_sum += ins_loc[i]->numel(); + conf_size_sum += ins_conf[i]->numel(); + } + + PADDLE_ENFORCE_EQ( + conf_size_sum, batch_size * prior_num * class_num, + "Sum of the sizes of inputs(conf) and batch_size * prior_num * " + "class_num must be the same."); + + auto loc_buffer_dim = framework::make_ddim({1, loc_size_sum}); + loc_buffer_cpu.mutable_data(loc_buffer_dim, platform::CPUPlace()); + + auto conf_buffer_dim = framework::make_ddim({1, conf_size_sum}); + conf_buffer_cpu.mutable_data(conf_buffer_dim, platform::CPUPlace()); + + math::SetConstant set_constant; + set_constant(cpu_ctx, &loc_buffer_cpu, 0); + set_constant(cpu_ctx, &conf_buffer_cpu, 0); + + int loc_offset = 0; + int conf_offset = 0; + + for (int i = 0; i < input_num; ++i) { + auto in_loc = ins_loc[i]; + auto in_conf = ins_conf[i]; + + framework::Tensor loc; + framework::Tensor conf; + + if (platform::is_gpu_place(ctx.GetPlace())) { + framework::CopyFrom(*in_loc, platform::CPUPlace(), ctx.device_context(), + &loc); + framework::CopyFrom(*in_conf, platform::CPUPlace(), + ctx.device_context(), &conf); + + loc_offset += + math::TransposeFromNCHWToNHWC( + platform::CPUPlace(), cpu_ctx, loc, loc_buffer_cpu, + loc_size_sum, loc_offset); + conf_offset += + math::TransposeFromNCHWToNHWC( + platform::CPUPlace(), cpu_ctx, conf, conf_buffer_cpu, + conf_size_sum, conf_offset); + } else { + loc_offset += + math::TransposeFromNCHWToNHWC( + platform::CPUPlace(), cpu_ctx, *in_loc, loc_buffer_cpu, + loc_size_sum, loc_offset); + conf_offset += + math::TransposeFromNCHWToNHWC( + platform::CPUPlace(), cpu_ctx, *in_conf, conf_buffer_cpu, + conf_size_sum, conf_offset); + } + } + + std::vector> all_max_conf_score; + GetMaxConfidenceScores(conf_buffer_cpu, batch_size, prior_num, class_num, + background_label_id, all_max_conf_score); + + out_all_match_indices->mutable_data({batch_size, prior_num}, + platform::CPUPlace()); + + out_all_neg_indices->mutable_data({batch_size, prior_num}, + platform::CPUPlace()); + + auto all_match_indices = + framework::EigenMatrix::From(*out_all_match_indices); + all_match_indices.setConstant(-1); + + auto all_neg_indices = + framework::EigenMatrix::From(*out_all_neg_indices); + all_neg_indices.setConstant(-1); + + int total_match = 0; + int total_neg = 0; + framework::Tensor priorbox_cpu; + framework::LoDTensor label_cpu; + if (platform::is_gpu_place(ctx.GetPlace())) { + priorbox_cpu.mutable_data(in_priorbox->dims(), platform::CPUPlace()); + label_cpu.mutable_data(in_label->dims(), platform::CPUPlace()); + framework::CopyFrom(*in_priorbox, platform::CPUPlace(), + ctx.device_context(), &priorbox_cpu); + framework::CopyFrom(*in_label, platform::CPUPlace(), ctx.device_context(), + &label_cpu); + label_cpu.set_lod(in_label->lod()); + + GenerateMatchIndices(priorbox_cpu, prior_num, label_cpu, + all_max_conf_score, batch_size, overlap_threshold, + neg_overlap, neg_pos_ratio, all_match_indices, + all_neg_indices, total_match, total_neg); + } else { + GenerateMatchIndices(*in_priorbox, prior_num, *in_label, + all_max_conf_score, batch_size, overlap_threshold, + neg_overlap, neg_pos_ratio, all_match_indices, + all_neg_indices, total_match, total_neg); + } + + T loc_loss = 0.0; + T conf_loss = 0.0; + int num_conf = total_match + total_neg; + if (platform::is_gpu_place(ctx.GetPlace())) { + if (total_match >= 1) { + loc_loss = + CalcLocationLoss(ctx, priorbox_cpu, loc_buffer_cpu, label_cpu, + total_match, batch_size, prior_num, + all_match_indices, *out_loc_gt, *out_loc_diff); + } + + if (num_conf >= 1) { + conf_loss = CalcConfidenceLoss( + ctx, priorbox_cpu, conf_buffer_cpu, label_cpu, total_match, + total_neg, batch_size, prior_num, background_label_id, class_num, + all_match_indices, all_neg_indices, *out_conf_gt, *out_conf_prob); + } + } else { + if (total_match >= 1) { + loc_loss = + CalcLocationLoss(ctx, *in_priorbox, loc_buffer_cpu, *in_label, + total_match, batch_size, prior_num, + all_match_indices, *out_loc_gt, *out_loc_diff); + } + + if (num_conf >= 1) { + conf_loss = CalcConfidenceLoss( + ctx, *in_priorbox, conf_buffer_cpu, *in_label, total_match, + total_neg, batch_size, prior_num, background_label_id, class_num, + all_match_indices, all_neg_indices, *out_conf_gt, *out_conf_prob); + } + } + + T loss = loc_loss + conf_loss; + + out_loss->mutable_data(ctx.GetPlace()); + if (platform::is_gpu_place(ctx.GetPlace())) { + framework::Tensor loss_cpu; + T* loss_data = + loss_cpu.mutable_data(out_loss->dims(), platform::CPUPlace()); + loss_data[0] = loss; + framework::CopyFrom(loss_cpu, ctx.GetPlace(), ctx.device_context(), + out_loss); + } else { + T* loss_data = out_loss->mutable_data(ctx.GetPlace()); + loss_data[0] = loss; + } + out_inter_couter->mutable_data(platform::CPUPlace()); + auto inter_couter = framework::EigenVector::Flatten(*out_inter_couter); + + inter_couter(0) = total_match; + inter_couter(1) = total_neg; + inter_couter(2) = num_conf; + } + + private: + void GetMaxConfidenceScores( + const framework::Tensor& conf, int batch_size, int prior_num, + int class_num, int background_label_id, + std::vector>& all_max_conf_score) const { + all_max_conf_score.clear(); + const T* conf_data = conf.data(); + for (int i = 0; i < batch_size; ++i) { + std::vector max_conf_score; + for (int j = 0; j < prior_num; ++j) { + int offset = j * class_num; + T max_val = -std::numeric_limits::max(); + T max_pos_val = -std::numeric_limits::max(); + T max_score = 0.0; + for (int c = 0; c < class_num; ++c) { + max_val = std::max(conf_data[offset + c], max_val); + if (c != background_label_id) + max_pos_val = std::max(conf_data[offset + c], max_pos_val); + } + T sum = 0.0; + for (int c = 0; c < class_num; ++c) + sum += std::exp(conf_data[offset + c] - max_val); + max_score = std::exp(max_pos_val - max_val) / sum; + max_conf_score.push_back(max_score); + } + conf_data += prior_num * class_num; + all_max_conf_score.push_back(max_conf_score); + } + return; + } + + void GenerateMatchIndices( + const framework::Tensor& priorbox, int num_prior_bboxes, + const framework::LoDTensor& label, + const std::vector>& max_conf_score, int batch_size, + float overlap_threshold, float neg_overlap_threshold, int neg_pos_ratio, + framework::EigenMatrix::Type& all_match_indices, + framework::EigenMatrix::Type& all_neg_indices, int& total_match, + int& total_neg) const { + std::vector> prior_bboxes; + GetBBoxFromPriorData(priorbox.data(), num_prior_bboxes, prior_bboxes); + + auto label_lod = label.lod(); + auto label_index = label_lod[0]; + auto label_data_num = static_cast(label_lod[0].size()); + + total_match = 0; + total_neg = 0; + for (int n = 0; n < batch_size; ++n) { + std::vector match_indices; + std::vector neg_indices; + std::vector match_overlaps; + match_indices.resize(num_prior_bboxes, -1); + match_overlaps.resize(num_prior_bboxes, 0.0); + size_t num_gt_bboxes = 0; + if (n < label_data_num) + num_gt_bboxes = label_index[n + 1] - label_index[n]; + if (num_gt_bboxes == 0) { + continue; + } + std::vector> gt_bboxes; + GetBBoxFromLabelData(label.data() + label_index[n] * 6, num_gt_bboxes, + gt_bboxes); + + MatchBBox(prior_bboxes, gt_bboxes, overlap_threshold, match_indices, + match_overlaps); + + size_t num_pos = 0; + size_t neg_num = 0; + for (size_t i = 0; i < match_indices.size(); ++i) + if (match_indices[i] != -1) ++num_pos; + total_match += num_pos; + std::vector> scores_indices; + for (size_t i = 0; i < match_indices.size(); ++i) + if (match_indices[i] == -1 && + match_overlaps[i] < neg_overlap_threshold) { + scores_indices.push_back(std::make_pair(max_conf_score[n][i], i)); + ++neg_num; + } + neg_num = std::min(static_cast(num_pos * neg_pos_ratio), neg_num); + std::sort(scores_indices.begin(), scores_indices.end(), + math::SortScorePairDescend); + for (size_t i = 0; i < neg_num; ++i) + neg_indices.push_back(scores_indices[i].second); + total_neg += neg_num; + for (size_t i = 0; i < match_indices.size(); ++i) { + all_match_indices(n, i) = match_indices[i]; + } + + for (size_t i = 0; i < neg_indices.size(); ++i) { + all_neg_indices(n, i) = neg_indices[i]; + } + } + return; + } + + T CalcLocationLoss(const framework::ExecutionContext& ctx, + const framework::Tensor& priorbox, + const framework::Tensor& loc_buffer, + const framework::LoDTensor& label, int match_num, + int batch_size, int prior_num, + framework::EigenMatrix::Type& all_match_indices, + framework::Tensor& loc_gt, + framework::Tensor& loc_diff) const { + T loc_loss = 0.0; + auto label_lod = label.lod(); + auto label_index = label_lod[0]; + + size_t count = 0; + auto loc_dim = framework::make_ddim({match_num * 4, 1}); + T* loc_diff_data = loc_diff.mutable_data(loc_dim, platform::CPUPlace()); + T* loc_gt_data = loc_gt.mutable_data(loc_dim, platform::CPUPlace()); + + int loc_gt_offset = 0; + const T* loc_buffer_data = loc_buffer.data(); + for (int n = 0; n < batch_size; ++n) { + for (int i = 0; i < prior_num; ++i) { + if (all_match_indices(n, i) == -1) continue; // match none + size_t loc_offset = n * (loc_buffer.numel() / batch_size) + i * 4; + std::copy(loc_buffer_data + loc_offset, + loc_buffer_data + loc_offset + 4, loc_diff_data + count); + count += 4; + const int gt_idx = all_match_indices(n, i); + size_t prior_offset = i * 8; + std::vector> prior_bboxes; + GetBBoxFromPriorData(priorbox.data() + prior_offset, 1, + prior_bboxes); + std::vector> prior_bbox_var; + math::GetBBoxVarFromPriorData(priorbox.data() + prior_offset, 1, + prior_bbox_var); + size_t label_offset = (label_index[n] + gt_idx) * 6; + std::vector> gt_bboxes; + GetBBoxFromLabelData(label.data() + label_offset, 1, gt_bboxes); + std::vector gt_encode; + EncodeBBoxWithVar(prior_bboxes[0], prior_bbox_var[0], gt_bboxes[0], + gt_encode); + std::copy(gt_encode.begin(), gt_encode.end(), + loc_gt_data + loc_gt_offset); + loc_gt_offset += gt_encode.size(); + } + } + loc_loss = MultiBoxLossSmoothL1(ctx, loc_diff, loc_gt, match_num, 0.0); + return loc_loss; + } + + T CalcConfidenceLoss(const framework::ExecutionContext& ctx, + const framework::Tensor& priorbox, + const framework::Tensor& conf_buffer, + const framework::LoDTensor& label, int match_num, + int neg_num, int batch_size, int prior_num, + int background_label_id, int class_num, + framework::EigenMatrix::Type& all_match_indices, + framework::EigenMatrix::Type& all_neg_indices, + framework::Tensor& conf_gt, + framework::Tensor& conf_prob) const { + T conf_loss = 0; + auto label_lod = label.lod(); + auto label_index = label_lod[0]; + + size_t count = 0; + T* conf_prob_data = conf_prob.mutable_data( + {match_num + neg_num, class_num}, platform::CPUPlace()); + int64_t* conf_gt_data = conf_gt.mutable_data( + {match_num + neg_num, 1}, platform::CPUPlace()); + const T* conf_buffer_data = conf_buffer.data(); + + platform::CPUPlace cpu_place; + platform::CPUDeviceContext cpu_ctx(cpu_place); + + math::SetConstant set_constant_t; + math::SetConstant set_constant_i; + set_constant_t(cpu_ctx, &conf_prob, 0); + set_constant_i(cpu_ctx, &conf_gt, 0); + + for (int n = 0; n < batch_size; ++n) { + for (int i = 0; i < prior_num; ++i) { + if (all_match_indices(n, i) == -1) continue; + size_t label_offset = (label_index[n] + all_match_indices(n, i)) * 6; + const int gt_label = (label.data() + label_offset)[0]; + + conf_gt_data[count] = gt_label; + size_t conf_offset = n * prior_num * class_num + i * class_num; + std::copy(conf_buffer_data + conf_offset, + conf_buffer_data + conf_offset + class_num, + conf_prob_data + count * class_num); + ++count; + } + // Negative mining samples + for (int i = 0; i < prior_num; ++i) { + if (all_neg_indices(n, i) == -1) continue; + conf_gt_data[count] = background_label_id; + size_t conf_offset = + n * prior_num * class_num + all_neg_indices(n, i) * class_num; + std::copy(conf_buffer_data + conf_offset, + conf_buffer_data + conf_offset + class_num, + conf_prob_data + count * class_num); + ++count; + } + } + + math::SoftmaxFunctor()(cpu_ctx, &conf_prob, + &conf_prob); + + framework::Tensor conf_loss_out; + auto conf_loss_data = conf_loss_out.mutable_data( + {match_num + neg_num, 1}, platform::CPUPlace()); + + math::CrossEntropyFunctor()( + cpu_ctx, &conf_loss_out, &conf_prob, &conf_gt, false); + + conf_loss = 0.0; + for (int i = 0; i < conf_loss_out.numel(); ++i) { + conf_loss += conf_loss_data[i]; + } + conf_loss = conf_loss / match_num; + return conf_loss; + } +}; // namespace operators + +template +void MultiBoxLossSmoothL1BP(const framework::ExecutionContext& ctx, + const framework::Tensor& output, + const framework::Tensor& label, + framework::Tensor& grad, int match_num, + T dest_scale) { + auto sample_num = output.dims()[0]; + auto dim = output.numel() / sample_num; + + const T* out_data = output.data(); + const T* label_data = label.data(); + T* grad_data = grad.mutable_data(platform::CPUPlace()); + + for (int i = 0; i < sample_num; + ++i, out_data += dim, grad_data += dim, label_data += dim) { + for (int j = 0; j < dim; ++j) { + T val = out_data[j] - label_data[j]; + grad_data[j] *= dest_scale; + if (std::fabs(val) < 1) { + grad_data[j] += val; + } else { + grad_data[j] += (T(0) < val) - (val < T(0)); + } + } + } +} + +template +class MultiBoxLossGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto d_ins_loc = + ctx.MultiOutput(framework::GradVarName("Loc")); + auto d_ins_conf = + ctx.MultiOutput(framework::GradVarName("Conf")); + + auto ins_loc = ctx.MultiInput("Loc"); + auto ins_conf = ctx.MultiInput("Conf"); + auto* in_priorbox = ctx.Input("PriorBox"); + + auto* in_loc_gt = ctx.Input("LocGTData"); + auto* in_conf_gt = ctx.Input("ConfGTData"); + auto* in_loc_diff = ctx.Input("LocDiff"); + auto* in_conf_prob = ctx.Input("ConfProb"); + + auto* in_inter_couter = ctx.Input("InterCounter"); + auto inter_couter = framework::EigenVector::Flatten(*in_inter_couter); + int class_num = ctx.template Attr("class_num"); + + auto* in_all_match_indices = + ctx.Input("AllMatchIndices"); + + auto* in_all_neg_indices = ctx.Input("AllNegIndices"); + + int batch_size = ins_loc[0]->dims()[0]; + int match_num = inter_couter(0); + // int neg_num = inter_couter(1); + int conf_num = inter_couter(2); + int prior_num = in_priorbox->numel() / 8; + int input_num = ins_loc.size(); + + framework::Tensor loc_buffer; + framework::Tensor conf_buffer; + + int loc_size_sum = 0; + int conf_size_sum = 0; + for (int i = 0; i < input_num; ++i) { + loc_size_sum += ins_loc[i]->numel(); + conf_size_sum += ins_conf[i]->numel(); + } + + auto loc_buffer_dim = framework::make_ddim({1, loc_size_sum}); + loc_buffer.mutable_data(loc_buffer_dim, platform::CPUPlace()); + + auto conf_buffer_dim = framework::make_ddim({1, conf_size_sum}); + conf_buffer.mutable_data(conf_buffer_dim, platform::CPUPlace()); + + platform::CPUPlace cpu_place; + platform::CPUDeviceContext cpu_ctx(cpu_place); + + math::SetConstant set_constant; + set_constant(cpu_ctx, &loc_buffer, 0); + set_constant(cpu_ctx, &conf_buffer, 0); + + auto all_match_indices = + framework::EigenMatrix::From(*in_all_match_indices); + + auto all_neg_indices = + framework::EigenMatrix::From(*in_all_neg_indices); + + if (match_num > 1) { + CalcLocationLossBP(ctx, *in_loc_diff, *in_loc_gt, loc_buffer, match_num, + batch_size, prior_num, all_match_indices); + } + + if (conf_num > 1) { + CalcConfidenceLossBP(ctx, *in_conf_gt, *in_conf_prob, conf_buffer, + match_num, conf_num, batch_size, prior_num, + class_num, all_match_indices, all_neg_indices); + } + + int loc_offset = 0; + int conf_offset = 0; + + for (int i = 0; i < input_num; ++i) { + auto d_loc = d_ins_loc[i]; + auto d_conf = d_ins_conf[i]; + + d_loc->mutable_data(ins_loc[i]->dims(), ctx.GetPlace()); + d_conf->mutable_data(ins_conf[i]->dims(), ctx.GetPlace()); + + math::SetConstant set_constant; + set_constant(ctx.template device_context(), d_loc, 0); + set_constant(ctx.template device_context(), d_conf, 0); + + if (platform::is_gpu_place(ctx.GetPlace())) { + framework::Tensor d_loc_cpu; + framework::Tensor d_conf_cpu; + d_loc_cpu.mutable_data(d_loc->dims(), platform::CPUPlace()); + d_conf_cpu.mutable_data(d_conf->dims(), platform::CPUPlace()); + + math::SetConstant set_constant; + set_constant(cpu_ctx, &d_loc_cpu, 0); + set_constant(cpu_ctx, &d_conf_cpu, 0); + loc_offset += + math::TransposeFromNHWCToNCHW( + platform::CPUPlace(), cpu_ctx, loc_buffer, loc_size_sum, + loc_offset, d_loc_cpu); + conf_offset += + math::TransposeFromNHWCToNCHW( + platform::CPUPlace(), cpu_ctx, conf_buffer, conf_size_sum, + conf_offset, d_conf_cpu); + + framework::CopyFrom(d_loc_cpu, ctx.GetPlace(), ctx.device_context(), + d_loc); + framework::CopyFrom(d_conf_cpu, ctx.GetPlace(), ctx.device_context(), + d_conf); + + } else { + loc_offset += + math::TransposeFromNHWCToNCHW( + platform::CPUPlace(), cpu_ctx, loc_buffer, loc_size_sum, + loc_offset, *d_loc); + conf_offset += + math::TransposeFromNHWCToNCHW( + platform::CPUPlace(), cpu_ctx, conf_buffer, conf_size_sum, + conf_offset, *d_conf); + } + } + } + + private: + void CalcLocationLossBP( + const framework::ExecutionContext& ctx, const framework::Tensor& loc_diff, + const framework::Tensor& loc_gt, framework::Tensor& loc_buffer, + int match_num, int batch_size, int prior_num, + framework::EigenMatrix::Type& all_match_indices) const { + framework::Tensor loc_diff_buffer; + loc_diff_buffer.mutable_data(loc_diff.dims(), platform::CPUPlace()); + MultiBoxLossSmoothL1BP(ctx, loc_diff, loc_gt, loc_diff_buffer, match_num, + 0.0); + // scale gradient + auto loc_diff_data = loc_diff_buffer.data(); + for (int i = 0; i < match_num * 4; ++i) + loc_diff_data[i] *= (1. / match_num); + // Copy gradient back + size_t count = 0; + for (int n = 0; n < batch_size; ++n) { + for (int i = 0; i < prior_num; ++i) { + if (all_match_indices(n, i) == -1) continue; + T* loc_buffer_data = loc_buffer.data() + n * prior_num * 4 + i * 4; + std::copy(loc_diff_data + count * 4, loc_diff_data + (count + 1) * 4, + loc_buffer_data); + ++count; + } + } + } + + void CalcConfidenceLossBP( + const framework::ExecutionContext& ctx, const framework::Tensor& conf_gt, + const framework::Tensor& conf_prob, framework::Tensor& conf_buffer, + int match_num, int conf_num, int batch_size, int prior_num, int class_num, + framework::EigenMatrix::Type& all_match_indices, + framework::EigenMatrix::Type& all_neg_indices) const { + framework::Tensor conf_prob_temp; + conf_prob_temp.mutable_data(conf_prob.dims(), platform::CPUPlace()); + framework::CopyFrom(conf_prob, platform::CPUPlace(), ctx.device_context(), + &conf_prob_temp); + auto conf_prob_data = conf_prob_temp.data(); + auto conf_gt_data = conf_gt.data(); + for (int i = 0; i < conf_num; ++i) + conf_prob_data[i * class_num + conf_gt_data[i]] -= 1; + + for (int i = 0; i < conf_num * class_num; ++i) + conf_prob_data[i] *= (1. / match_num); + size_t count = 0; + + for (int n = 0; n < batch_size; ++n) { + for (int i = 0; i < prior_num; ++i) { + if (all_match_indices(n, i) == -1) continue; + T* conf_diff_data = + conf_buffer.data() + n * prior_num * class_num + i * class_num; + std::copy(conf_prob_data + count * class_num, + conf_prob_data + (count + 1) * class_num, conf_diff_data); + ++count; + } + for (int i = 0; i < prior_num; ++i) { + if (all_neg_indices(n, i) == -1) continue; + int idx = all_neg_indices(n, i); + T* conf_diff_data = + conf_buffer.data() + n * prior_num * class_num + idx * class_num; + std::copy(conf_prob_data + count * class_num, + conf_prob_data + (count + 1) * class_num, conf_diff_data); + ++count; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_multi_box_loss_op.py b/python/paddle/v2/fluid/tests/test_multi_box_loss_op.py new file mode 100755 index 00000000000000..39418abd2dbbdf --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_multi_box_loss_op.py @@ -0,0 +1,120 @@ +import unittest +import numpy as np +import sys +from op_test import OpTest + + +class TestMultiBoxLossOp(OpTest): + def set_data(self): + self.init_test_case() + + xxx = np.random.random((1, 2, 3, 4)).astype('float32') + + self.inputs = { + 'Loc': self.loc, + 'Conf': self.conf, + 'PriorBox': self.prior_box, + 'Label': (self.label, self.label_lod) + } + + self.attrs = { + 'class_num': self.classes_num, + 'overlap_threshold': self.overlap_threshold, + 'neg_pos_ratio': self.neg_pos_ratio, + 'neg_overlap': self.neg_overlap, + 'background_label_id': self.background_id + } + + self.outputs = { + 'Loss': self.loss, + 'InterCounter': self.inter_counter, + 'AllMatchIndices': self.all_match_indices, + 'AllNegIndices': self.all_neg_indices, + 'LocGTData': self.loc_gt_data, + 'ConfGTData': self.conf_gt_data, + 'LocDiff': self.loc_diff, + 'ConfProb': self.conf_prob + } + + def init_test_case(self): + self.input_num = 2 + self.classes_num = 3 + self.overlap_threshold = 0.3 + self.neg_pos_ratio = 3.0 + self.neg_overlap = 0.5 + self.background_id = 0 + + loc0 = [-0.768, -1.032, 0.046, 1.613, -0.205, 2.643, 2.771, 0.207] + loc1 = [-1.246, 0.096, -0.194, 0.554, -1.722, -2.082, -2.450, 1.673] + + conf0 = [-0.289, -2.602, 0.334, 0.718, -1.706, -2.971] + conf1 = [-3.235, -2.102, 1.241, -3.959, -1.846, 0.310] + + # dim = {2, 2, 2, 2} + loc0 = np.array(loc0).reshape((2, 4, 1, 1)).astype('float32') + loc1 = np.array(loc1).reshape((2, 4, 1, 1)).astype('float32') + + self.loc0 = loc0 + + # dim = {2, 3, 2, 2} + conf0 = np.array(conf0).reshape((2, 3, 1, 1)).astype('float32') + conf1 = np.array(conf1).reshape((2, 3, 1, 1)).astype('float32') + + self.loc = [('loc0', loc0), ('loc1', loc1)] + self.conf = [('conf0', conf0), ('conf1', conf1)] + + self.prior_box = [ + 0.1, 0.1, 0.5, 0.5, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.6, 0.6, 0.1, + 0.1, 0.2, 0.2 + ] + self.prior_box = np.array(self.prior_box).astype('float32') + + self.label_lod = [[0, 2, 4]] + self.label = [[1, 0.1, 0.1, 0.3, 0.3, 0], [1, 0.6, 0.6, 0.8, 0.8, 1], + [2, 0.3, 0.3, 0.6, 0.5, 0], [1, 0.7, 0.1, 0.9, 0.3, 0]] + + self.label = np.array(self.label).astype('float32') + + self.inter_counter = [2, 2, 4] + self.inter_counter = np.array(self.inter_counter).flatten().astype( + 'int64') + + self.all_match_indices = [[0, -1], [-1, 0]] + self.all_match_indices = np.array(self.all_match_indices).astype( + 'int64') + + self.all_neg_indices = [[1, -1], [0, -1]] + self.all_neg_indices = np.array(self.all_neg_indices).astype('int64') + + self.loc_gt_data = [[-2.5], [-2.5], [-3.466], [-3.466], [1.25], [0.], + [-1.44], [-3.466]] + self.loc_gt_data = np.array(self.loc_gt_data).astype('float32') + + self.conf_gt_data = [[1], [0], [2], [0]] + self.conf_gt_data = np.array(self.conf_gt_data).astype('int64') + + self.loc_diff = [[-0.768], [-1.032], [0.046], [1.613], [-1.722], + [-2.082], [-2.45], [1.673]] + self.loc_diff = np.array(self.loc_diff).astype('float32') + + self.conf_prob = [[0.33744144, 0.03339453, 0.62916404], + [0.01087106, 0.03375417, 0.95537478], + [0.01238802, 0.10248636, 0.88512564], + [0.89801782, 0.07953442, 0.02244774]] + self.conf_prob = np.array(self.conf_prob).astype('float32') + + self.loss = np.array([13.57]).flatten().astype('float32') + + def setUp(self): + self.op_type = "multi_box_loss" + self.set_data() + + def test_check_output(self): + self.check_output(atol=0.01) + + def test_check_grad(self): + self.check_grad(['loc0', 'loc1', 'conf0', 'conf1'], 'Loss') + + +if __name__ == '__main__': + unittest.main()