Skip to content

Commit a824da9

Browse files
authored
Merge pull request #6588 from wanghaox/detection_map
detection map evaluator for SSD
2 parents e9d3099 + 91a2188 commit a824da9

File tree

3 files changed

+900
-0
lines changed

3 files changed

+900
-0
lines changed
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/detection_map_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
using Tensor = framework::Tensor;
21+
22+
class DetectionMAPOp : public framework::OperatorWithKernel {
23+
public:
24+
using framework::OperatorWithKernel::OperatorWithKernel;
25+
26+
void InferShape(framework::InferShapeContext* ctx) const override {
27+
PADDLE_ENFORCE(ctx->HasInput("DetectRes"),
28+
"Input(DetectRes) of DetectionMAPOp should not be null.");
29+
PADDLE_ENFORCE(ctx->HasInput("Label"),
30+
"Input(Label) of DetectionMAPOp should not be null.");
31+
PADDLE_ENFORCE(
32+
ctx->HasOutput("AccumPosCount"),
33+
"Output(AccumPosCount) of DetectionMAPOp should not be null.");
34+
PADDLE_ENFORCE(
35+
ctx->HasOutput("AccumTruePos"),
36+
"Output(AccumTruePos) of DetectionMAPOp should not be null.");
37+
PADDLE_ENFORCE(
38+
ctx->HasOutput("AccumFalsePos"),
39+
"Output(AccumFalsePos) of DetectionMAPOp should not be null.");
40+
PADDLE_ENFORCE(ctx->HasOutput("MAP"),
41+
"Output(MAP) of DetectionMAPOp should not be null.");
42+
43+
auto det_dims = ctx->GetInputDim("DetectRes");
44+
PADDLE_ENFORCE_EQ(det_dims.size(), 2UL,
45+
"The rank of Input(DetectRes) must be 2, "
46+
"the shape is [N, 6].");
47+
PADDLE_ENFORCE_EQ(det_dims[1], 6UL,
48+
"The shape is of Input(DetectRes) [N, 6].");
49+
auto label_dims = ctx->GetInputDim("Label");
50+
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
51+
"The rank of Input(Label) must be 2, "
52+
"the shape is [N, 6].");
53+
PADDLE_ENFORCE_EQ(label_dims[1], 6UL,
54+
"The shape is of Input(Label) [N, 6].");
55+
56+
if (ctx->HasInput("PosCount")) {
57+
PADDLE_ENFORCE(ctx->HasInput("TruePos"),
58+
"Input(TruePos) of DetectionMAPOp should not be null when "
59+
"Input(TruePos) is not null.");
60+
PADDLE_ENFORCE(
61+
ctx->HasInput("FalsePos"),
62+
"Input(FalsePos) of DetectionMAPOp should not be null when "
63+
"Input(FalsePos) is not null.");
64+
}
65+
66+
ctx->SetOutputDim("MAP", framework::make_ddim({1}));
67+
}
68+
69+
protected:
70+
framework::OpKernelType GetExpectedKernelType(
71+
const framework::ExecutionContext& ctx) const override {
72+
return framework::OpKernelType(
73+
framework::ToDataType(
74+
ctx.Input<framework::Tensor>("DetectRes")->type()),
75+
ctx.device_context());
76+
}
77+
};
78+
79+
class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
80+
public:
81+
DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker)
82+
: OpProtoAndCheckerMaker(proto, op_checker) {
83+
AddInput("DetectRes",
84+
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
85+
"detections. Each row has 6 values: "
86+
"[label, confidence, xmin, ymin, xmax, ymax], M is the total "
87+
"number of detect results in this mini-batch. For each instance, "
88+
"the offsets in first dimension are called LoD, the number of "
89+
"offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
90+
"no detected data.");
91+
AddInput("Label",
92+
"(LoDTensor) A 2-D LoDTensor with shape[N, 6] represents the"
93+
"Labeled ground-truth data. Each row has 6 values: "
94+
"[label, is_difficult, xmin, ymin, xmax, ymax], N is the total "
95+
"number of ground-truth data in this mini-batch. For each "
96+
"instance, the offsets in first dimension are called LoD, "
97+
"the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, "
98+
"means there is no ground-truth data.");
99+
AddInput("PosCount",
100+
"(Tensor) A tensor with shape [Ncls, 1], store the "
101+
"input positive example count of each class, Ncls is the count of "
102+
"input classification. "
103+
"This input is used to pass the AccumPosCount generated by the "
104+
"previous mini-batch when the multi mini-batches cumulative "
105+
"calculation carried out. "
106+
"When the input(PosCount) is empty, the cumulative "
107+
"calculation is not carried out, and only the results of the "
108+
"current mini-batch are calculated.")
109+
.AsDispensable();
110+
AddInput("TruePos",
111+
"(LoDTensor) A 2-D LoDTensor with shape [Ntp, 2], store the "
112+
"input true positive example of each class."
113+
"This input is used to pass the AccumTruePos generated by the "
114+
"previous mini-batch when the multi mini-batches cumulative "
115+
"calculation carried out. ")
116+
.AsDispensable();
117+
AddInput("FalsePos",
118+
"(LoDTensor) A 2-D LoDTensor with shape [Nfp, 2], store the "
119+
"input false positive example of each class."
120+
"This input is used to pass the AccumFalsePos generated by the "
121+
"previous mini-batch when the multi mini-batches cumulative "
122+
"calculation carried out. ")
123+
.AsDispensable();
124+
AddOutput("AccumPosCount",
125+
"(Tensor) A tensor with shape [Ncls, 1], store the "
126+
"positive example count of each class. It combines the input "
127+
"input(PosCount) and the positive example count computed from "
128+
"input(Detection) and input(Label).");
129+
AddOutput("AccumTruePos",
130+
"(LoDTensor) A LoDTensor with shape [Ntp', 2], store the "
131+
"true positive example of each class. It combines the "
132+
"input(TruePos) and the true positive examples computed from "
133+
"input(Detection) and input(Label).");
134+
AddOutput("AccumFalsePos",
135+
"(LoDTensor) A LoDTensor with shape [Nfp', 2], store the "
136+
"false positive example of each class. It combines the "
137+
"input(FalsePos) and the false positive examples computed from "
138+
"input(Detection) and input(Label).");
139+
AddOutput("MAP",
140+
"(Tensor) A tensor with shape [1], store the mAP evaluate "
141+
"result of the detection.");
142+
143+
AddAttr<float>(
144+
"overlap_threshold",
145+
"(float) "
146+
"The lower bound jaccard overlap threshold of detection output and "
147+
"ground-truth data.")
148+
.SetDefault(.3f);
149+
AddAttr<bool>("evaluate_difficult",
150+
"(bool, default true) "
151+
"Switch to control whether the difficult data is evaluated.")
152+
.SetDefault(true);
153+
AddAttr<std::string>("ap_type",
154+
"(string, default 'integral') "
155+
"The AP algorithm type, 'integral' or '11point'.")
156+
.SetDefault("integral")
157+
.InEnum({"integral", "11point"})
158+
.AddCustomChecker([](const std::string& ap_type) {
159+
PADDLE_ENFORCE_NE(GetAPType(ap_type), APType::kNone,
160+
"The ap_type should be 'integral' or '11point.");
161+
});
162+
AddComment(R"DOC(
163+
Detection mAP evaluate operator.
164+
The general steps are as follows. First, calculate the true positive and
165+
false positive according to the input of detection and labels, then
166+
calculate the mAP evaluate value.
167+
Supporting '11 point' and 'integral' mAP algorithm. Please get more information
168+
from the following articles:
169+
https://sanchom.wordpress.com/tag/average-precision/
170+
https://arxiv.org/abs/1512.02325
171+
172+
)DOC");
173+
}
174+
};
175+
176+
} // namespace operators
177+
} // namespace paddle
178+
179+
namespace ops = paddle::operators;
180+
REGISTER_OP_WITHOUT_GRADIENT(detection_map, ops::DetectionMAPOp,
181+
ops::DetectionMAPOpMaker);
182+
REGISTER_OP_CPU_KERNEL(
183+
detection_map, ops::DetectionMAPOpKernel<paddle::platform::CPUPlace, float>,
184+
ops::DetectionMAPOpKernel<paddle::platform::CPUPlace, double>);

0 commit comments

Comments
 (0)