Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 58 additions & 9 deletions paddle/fluid/operators/detection/yolo_box_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {
Expand All @@ -31,19 +32,44 @@ class YoloBoxOp : public framework::OperatorWithKernel {
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
int anchor_num = anchors.size() / 2;
auto class_num = ctx->Attrs().Get<int>("class_num");
auto iou_aware = ctx->Attrs().Get<bool>("iou_aware");
auto iou_aware_factor = ctx->Attrs().Get<float>("iou_aware_factor");

PADDLE_ENFORCE_EQ(dim_x.size(), 4, platform::errors::InvalidArgument(
"Input(X) should be a 4-D tensor."
"But received X dimension(%s)",
dim_x.size()));
PADDLE_ENFORCE_EQ(
dim_x[1], anchor_num * (5 + class_num),
platform::errors::InvalidArgument(
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."
"But received dim[1](%s) != (anchor_mask_number * "
"(5+class_num)(%s).",
dim_x[1], anchor_num * (5 + class_num)));
if (iou_aware) {
PADDLE_ENFORCE_EQ(
dim_x[1], anchor_num * (6 + class_num),
platform::errors::InvalidArgument(
"Input(X) dim[1] should be equal to (anchor_mask_number * (6 "
"+ class_num)) while iou_aware is true."
"But received dim[1](%s) != (anchor_mask_number * "
"(6+class_num)(%s).",
dim_x[1], anchor_num * (6 + class_num)));
PADDLE_ENFORCE_GE(
iou_aware_factor, 0,
platform::errors::InvalidArgument(
"Attr(iou_aware_factor) should greater than or equal to 0."
"But received iou_aware_factor (%s)",
iou_aware_factor));
PADDLE_ENFORCE_LE(
iou_aware_factor, 1,
platform::errors::InvalidArgument(
"Attr(iou_aware_factor) should less than or equal to 1."
"But received iou_aware_factor (%s)",
iou_aware_factor));
} else {
PADDLE_ENFORCE_EQ(
dim_x[1], anchor_num * (5 + class_num),
platform::errors::InvalidArgument(
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."
"But received dim[1](%s) != (anchor_mask_number * "
"(5+class_num)(%s).",
dim_x[1], anchor_num * (5 + class_num)));
}
PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2,
platform::errors::InvalidArgument(
"Input(ImgSize) should be a 2-D tensor."
Expand Down Expand Up @@ -140,14 +166,19 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"Scale the center point of decoded bounding "
"box. Default 1.0")
.SetDefault(1.);
AddAttr<bool>("iou_aware", "Whether use iou aware. Default false.")
.SetDefault(false);
AddAttr<float>("iou_aware_factor", "iou aware factor. Default 0.5.")
.SetDefault(0.5);
AddComment(R"DOC(
This operator generates YOLO detection boxes from output of YOLOv3 network.

The output of previous network is in shape [N, C, H, W], while H and W
should be the same, H and W specify the grid size, each grid point predict
given number boxes, this given number, which following will be represented as S,
is specified by the number of anchors. In the second dimension(the channel
dimension), C should be equal to S * (5 + class_num), class_num is the object
dimension), C should be equal to S * (5 + class_num) if :attr:`iou_aware` is false,
otherwise C should be equal to S * (6 + class_num). class_num is the object
category number of source dataset(such as 80 in coco dataset), so the
second(channel) dimension, apart from 4 box location coordinates x, y, w, h,
also includes confidence score of the box and class one-hot key of each anchor
Expand Down Expand Up @@ -183,6 +214,15 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
score_{pred} = score_{conf} * score_{class}
$$

where the confidence score follows the formula below

.. math::

score_{conf} = \begin{cases}
obj, & \text{if } \text{iou\_aware} == \text{false} \\
obj^{1 - \text{iou\_aware\_factor}} * iou^{\text{iou\_aware\_factor}}, & \text{otherwise}
\end{cases}

)DOC");
}
};
Expand All @@ -197,3 +237,12 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel<float>,
ops::YoloBoxKernel<double>);

// Registers an op-version checkpoint for yolo_box. The checkpoint records that
// this revision introduces two new attributes together with the defaults passed
// to NewAttr below: iou_aware (false) and iou_aware_factor (0.5f).
REGISTER_OP_VERSION(yolo_box)
.AddCheckpoint(
R"ROC(
Upgrade yolo box to add new attribute [iou_aware, iou_aware_factor].
)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("iou_aware", "Whether use iou aware", false)
.NewAttr("iou_aware_factor", "iou aware factor", 0.5f));
25 changes: 17 additions & 8 deletions paddle/fluid/operators/detection/yolo_box_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
const int w, const int an_num, const int class_num,
const int box_num, int input_size_h,
int input_size_w, bool clip_bbox, const float scale,
const float bias) {
const float bias, bool iou_aware,
const float iou_aware_factor) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
T box[4];
Expand All @@ -43,23 +44,29 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int img_height = imgsize[2 * i];
int img_width = imgsize[2 * i + 1];

int obj_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4);
int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4,
iou_aware);
T conf = sigmoid<T>(input[obj_idx]);
if (iou_aware) {
int iou_idx = GetIoUIndex(i, j, k * w + l, an_num, an_stride, grid_num);
T iou = sigmoid<T>(input[iou_idx]);
conf = pow(conf, static_cast<T>(1. - iou_aware_factor)) *
pow(iou, static_cast<T>(iou_aware_factor));
}
if (conf < conf_thresh) {
continue;
}

int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0,
iou_aware);
GetYoloBox<T>(box, input, anchors, l, k, j, h, w, input_size_h,
input_size_w, box_idx, grid_num, img_height, img_width, scale,
bias);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);

int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num,
5, iou_aware);
int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf,
grid_num);
Expand All @@ -80,6 +87,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox");
bool iou_aware = ctx.Attr<bool>("iou_aware");
float iou_aware_factor = ctx.Attr<float>("iou_aware_factor");
float scale = ctx.Attr<float>("scale_x_y");
float bias = -0.5 * (scale - 1.);

Expand Down Expand Up @@ -115,7 +124,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size_h,
input_size_w, clip_bbox, scale, bias);
input_size_w, clip_bbox, scale, bias, iou_aware, iou_aware_factor);
}
};

Expand Down
37 changes: 29 additions & 8 deletions paddle/fluid/operators/detection/yolo_box_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
Expand Down Expand Up @@ -43,8 +44,19 @@ HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,

// Returns the flat offset of entry `entry` at spatial position `hw_idx` for
// anchor `an_idx` of sample `batch`.
//   an_stride : element count skipped per (batch, anchor) pair
//   stride    : element count skipped per entry
// With iou_aware == false the offset is
//   (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx.
// With iou_aware == true an additional (batch * an_num + an_num) * stride
// elements are skipped.
// NOTE(review): presumably this steps over the IoU planes addressed by
// GetIoUIndex, which would sit ahead of the per-anchor entries of each
// sample -- confirm against the input layout.
HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
int an_num, int an_stride, int stride,
int entry) {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
int entry, bool iou_aware) {
if (iou_aware) {
// Same base offset as the non-IoU case, shifted by
// (batch * an_num + an_num) extra planes of `stride` elements each.
return (batch * an_num + an_idx) * an_stride +
(batch * an_num + an_num + entry) * stride + hw_idx;
} else {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
}

// Returns the flat offset of the IoU value at spatial position `hw_idx` for
// anchor `an_idx` of sample `batch`.
//
// The offset is the sum of three parts:
//   * batch * an_num * an_stride         -- per-anchor entries of earlier samples
//   * (batch * an_num + an_idx) * stride -- IoU planes preceding the requested one
//   * hw_idx                             -- position inside the plane
// NOTE(review): this is consistent with each sample storing an_num IoU planes
// ahead of its an_num * an_stride per-anchor entries -- confirm against the
// producer of the input tensor.
HOSTDEVICE inline int GetIoUIndex(int batch, int an_idx, int hw_idx, int an_num,
                                  int an_stride, int stride) {
  const int entries_before = batch * an_num * an_stride;
  const int iou_plane = batch * an_num + an_idx;
  return entries_before + iou_plane * stride + hw_idx;
}

template <typename T>
Expand Down Expand Up @@ -92,6 +104,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox");
bool iou_aware = ctx.Attr<bool>("iou_aware");
float iou_aware_factor = ctx.Attr<float>("iou_aware_factor");
float scale = ctx.Attr<float>("scale_x_y");
float bias = -0.5 * (scale - 1.);

Expand Down Expand Up @@ -127,24 +141,31 @@ class YoloBoxKernel : public framework::OpKernel<T> {
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
int obj_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 4);
int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride,
stride, 4, iou_aware);
T conf = sigmoid<T>(input_data[obj_idx]);
if (iou_aware) {
int iou_idx =
GetIoUIndex(i, j, k * w + l, an_num, an_stride, stride);
T iou = sigmoid<T>(input_data[iou_idx]);
conf = pow(conf, static_cast<T>(1. - iou_aware_factor)) *
pow(iou, static_cast<T>(iou_aware_factor));
}
if (conf < conf_thresh) {
continue;
}

int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride,
stride, 0, iou_aware);
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, w,
input_size_h, input_size_w, box_idx, stride,
img_height, img_width, scale, bias);
box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width,
clip_bbox);

int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5);
int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride,
stride, 5, iou_aware);
int score_idx = (i * box_num + j * stride + k * w + l) * class_num;
CalcLabelScore<T>(scores_data, input_data, label_idx, score_idx,
class_num, conf, stride);
Expand Down
8 changes: 7 additions & 1 deletion python/paddle/fluid/layers/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,7 +1139,9 @@ def yolo_box(x,
downsample_ratio,
clip_bbox=True,
name=None,
scale_x_y=1.):
scale_x_y=1.,
iou_aware=False,
iou_aware_factor=0.5):
"""

${comment}
Expand All @@ -1156,6 +1158,8 @@ def yolo_box(x,
name (string): The default value is None. Normally there is no need
for user to set this property. For more information,
please refer to :ref:`api_guide_Name`
iou_aware (bool): ${iou_aware_comment}
iou_aware_factor (float): ${iou_aware_factor_comment}

Returns:
Variable: A 3-D tensor with shape [N, M, 4], the coordinates of boxes,
Expand Down Expand Up @@ -1204,6 +1208,8 @@ def yolo_box(x,
"downsample_ratio": downsample_ratio,
"clip_bbox": clip_bbox,
"scale_x_y": scale_x_y,
"iou_aware": iou_aware,
"iou_aware_factor": iou_aware_factor
}

helper.append_op(
Expand Down
Loading