Skip to content

Commit e84615b

Browse files
Noplzqingqing01
authored andcommitted
Fix box coder op (#8647)
* fix ssd problems * fix box decoder op * fix dimension problem in detection tests * update detection doc * Update detection doc * Update detection doc * update detection doc * update detection doc
1 parent 9344e4e commit e84615b

File tree

6 files changed

+85
-59
lines changed

6 files changed

+85
-59
lines changed

paddle/fluid/operators/box_coder_op.cc

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,19 @@ class BoxCoderOp : public framework::OperatorWithKernel {
3737
"The rank of Input of PriorBoxVar must be 2");
3838
PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]");
3939
PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims);
40-
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
41-
"The rank of Input of TargetBox must be 2");
42-
PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
43-
"The shape of TargetBox is [M, 4]");
4440

45-
GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
41+
auto code_type = GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
42+
if (code_type == BoxCodeType::kEncodeCenterSize) {
43+
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
44+
"The rank of Input of TargetBox must be 2");
45+
PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
46+
"The shape of TargetBox is [M, 4]");
47+
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
48+
PADDLE_ENFORCE_EQ(target_box_dims.size(), 3,
49+
"The rank of Input of TargetBox must be 3");
50+
PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]);
51+
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
52+
}
4653

4754
ctx->SetOutputDim(
4855
"OutputBox",
@@ -70,25 +77,28 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
7077
"of variance.");
7178
AddInput(
7279
"TargetBox",
73-
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
74-
"[N, 4], each box is represented as [xmin, ymin, xmax, ymax], "
75-
"[xmin, ymin] is the left top coordinate of the box if the input "
76-
"is image feature map, they are close to the origin of the coordinate "
77-
"system. [xmax, ymax] is the right bottom coordinate of the box. "
78-
"This tensor can contain LoD information to represent a batch "
79-
"of inputs. One instance of this batch can contain different "
80-
"numbers of entities.");
80+
"(LoDTensor or Tensor) This input can be a 2-D LoDTensor with shape "
81+
"[N, 4] when code_type is 'encode_center_size'. This input also can "
82+
"be a 3-D Tensor with shape [N, M, 4] when code_type is "
83+
"'decode_center_size'. [N, 4], each box is represented as "
84+
"[xmin, ymin, xmax, ymax], [xmin, ymin] is the left top coordinate "
85+
"of the box if the input is image feature map, they are close to "
86+
"the origin of the coordinate system. [xmax, ymax] is the right "
87+
"bottom coordinate of the box. This tensor can contain LoD "
88+
"information to represent a batch of inputs. One instance of this "
89+
"batch can contain different numbers of entities.");
8190
AddAttr<std::string>("code_type",
8291
"(string, default encode_center_size) "
8392
"the code type used with the target box")
8493
.SetDefault("encode_center_size")
8594
.InEnum({"encode_center_size", "decode_center_size"});
86-
AddOutput(
87-
"OutputBox",
88-
"(LoDTensor or Tensor) "
89-
"(Tensor) The output of box_coder_op, a tensor with shape [N, M, 4] "
90-
"representing the result of N target boxes encoded/decoded with "
91-
"M Prior boxes and variances.");
95+
AddOutput("OutputBox",
96+
"(LoDTensor or Tensor) "
97+
"When code_type is 'encode_center_size', the output tensor of "
98+
"box_coder_op with shape [N, M, 4] representing the result of N "
99+
"target boxes encoded with M Prior boxes and variances. When "
100+
"code_type is 'decode_center_size', N represents the batch size "
101+
"and M represents the number of deocded boxes.");
92102

93103
AddComment(R"DOC(
94104
Bounding Box Coder Operator.

paddle/fluid/operators/box_coder_op.cu

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
6666
T* output) {
6767
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
6868
if (idx < row * col) {
69-
const int row_idx = idx / col;
7069
const int col_idx = idx % col;
7170
T prior_box_width =
7271
prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len];
@@ -79,17 +78,16 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
7978
2;
8079

8180
T target_box_width = exp(prior_box_var_data[col_idx * len + 2] *
82-
target_box_data[row_idx * len + 2]) *
81+
target_box_data[idx * len + 2]) *
8382
prior_box_width;
8483
T target_box_height = exp(prior_box_var_data[col_idx * len + 3] *
85-
target_box_data[row_idx * len + 3]) *
84+
target_box_data[idx * len + 3]) *
8685
prior_box_height;
8786
T target_box_center_x = prior_box_var_data[col_idx * len] *
88-
target_box_data[row_idx * len] *
89-
prior_box_width +
87+
target_box_data[idx * len] * prior_box_width +
9088
prior_box_center_x;
9189
T target_box_center_y = prior_box_var_data[col_idx * len + 1] *
92-
target_box_data[row_idx * len + 1] *
90+
target_box_data[idx * len + 1] *
9391
prior_box_height +
9492
prior_box_center_y;
9593

paddle/fluid/operators/box_coder_op.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
8989

9090
for (int64_t i = 0; i < row; ++i) {
9191
for (int64_t j = 0; j < col; ++j) {
92+
size_t offset = i * col * len + j * len;
9293
T prior_box_width =
9394
prior_box_data[j * len + 2] - prior_box_data[j * len];
9495
T prior_box_height =
@@ -99,20 +100,19 @@ class BoxCoderKernel : public framework::OpKernel<T> {
99100
(prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2;
100101

101102
T target_box_center_x = prior_box_var_data[j * len] *
102-
target_box_data[i * len] * prior_box_width +
103+
target_box_data[offset] * prior_box_width +
103104
prior_box_center_x;
104105
T target_box_center_y = prior_box_var_data[j * len + 1] *
105-
target_box_data[i * len + 1] *
106+
target_box_data[offset + 1] *
106107
prior_box_height +
107108
prior_box_center_y;
108109
T target_box_width = std::exp(prior_box_var_data[j * len + 2] *
109-
target_box_data[i * len + 2]) *
110+
target_box_data[offset + 2]) *
110111
prior_box_width;
111112
T target_box_height = std::exp(prior_box_var_data[j * len + 3] *
112-
target_box_data[i * len + 3]) *
113+
target_box_data[offset + 3]) *
113114
prior_box_height;
114115

115-
size_t offset = i * col * len + j * len;
116116
output[offset] = target_box_center_x - target_box_width / 2;
117117
output[offset + 1] = target_box_center_y - target_box_height / 2;
118118
output[offset + 2] = target_box_center_x + target_box_width / 2;

python/paddle/fluid/layers/detection.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343
globals()[_OP] = generate_layer_fn(_OP)
4444

4545

46-
def detection_output(scores,
47-
loc,
46+
def detection_output(loc,
47+
scores,
4848
prior_box,
4949
prior_box_var,
5050
background_label=0,
@@ -61,14 +61,14 @@ def detection_output(scores,
6161
be zero if there is no valid bounding box.
6262
6363
Args:
64-
scores(Variable): A 3-D Tensor with shape [N, C, M] represents the
65-
predicted confidence predictions. N is the batch size, C is the
66-
class number, M is number of bounding boxes. For each category
67-
there are total M scores which corresponding M bounding boxes.
6864
loc(Variable): A 3-D Tensor with shape [N, M, 4] represents the
6965
predicted locations of M bounding bboxes. N is the batch size,
7066
and each bounding box has four coordinate values and the layout
7167
is [xmin, ymin, xmax, ymax].
68+
scores(Variable): A 3-D Tensor with shape [N, M, C] represents the
69+
predicted confidence predictions. N is the batch size, C is the
70+
class number, M is number of bounding boxes. For each category
71+
there are total M scores which corresponding M bounding boxes.
7272
prior_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes,
7373
each box is represented as [xmin, ymin, xmax, ymax],
7474
[xmin, ymin] is the left top coordinate of the anchor box,
@@ -100,7 +100,7 @@ class number, M is number of bounding boxes. For each category
100100
append_batch_size=False, dtype='float32')
101101
pbv = layers.data(name='prior_box_var', shape=[10, 4],
102102
append_batch_size=False, dtype='float32')
103-
loc = layers.data(name='target_box', shape=[21, 4],
103+
loc = layers.data(name='target_box', shape=[2, 21, 4],
104104
append_batch_size=False, dtype='float32')
105105
scores = layers.data(name='scores', shape=[2, 21, 10],
106106
append_batch_size=False, dtype='float32')
@@ -109,7 +109,6 @@ class number, M is number of bounding boxes. For each category
109109
prior_box=pb,
110110
prior_box_var=pbv)
111111
"""
112-
113112
helper = LayerHelper("detection_output", **locals())
114113
decoded_box = box_coder(
115114
prior_box=prior_box,
@@ -118,6 +117,7 @@ class number, M is number of bounding boxes. For each category
118117
code_type='decode_center_size')
119118

120119
nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype)
120+
scores = nn.transpose(scores, perm=[0, 2, 1])
121121
helper.append_op(
122122
type="multiclass_nms",
123123
inputs={'Scores': scores,
@@ -595,12 +595,13 @@ def multi_box_head(inputs,
595595
name(str): Name of the prior box layer. Default: None.
596596
597597
Returns:
598-
mbox_loc(list): The predicted boxes' location of the inputs.
599-
The layout of each element is [N, H, W, Priors]. Priors
600-
is the number of predicted boxof each position of each input.
601-
mbox_conf(list): The predicted boxes' confidence of the inputs.
602-
The layout of each element is [N, H, W, Priors]. Priors
603-
is the number of predicted box of each position of each input.
598+
mbox_loc(Variable): The predicted boxes' location of the inputs.
599+
The layout is [N, H*W*Priors, 4]. where Priors
600+
is the number of predicted boxes each position of each input.
601+
mbox_conf(Variable): The predicted boxes' confidence of the inputs.
602+
The layout is [N, H*W*Priors, C]. where Priors
603+
is the number of predicted boxes each position of each input
604+
and C is the number of Classes.
604605
boxes(Variable): the output prior boxes of PriorBox.
605606
The layout is [num_priors, 4]. num_priors is the total
606607
box count of each position of inputs.
@@ -751,7 +752,7 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
751752
num_boxes = box.shape[2]
752753

753754
# get box_loc
754-
num_loc_output = num_boxes * num_classes * 4
755+
num_loc_output = num_boxes * 4
755756
mbox_loc = nn.conv2d(
756757
input=input,
757758
num_filters=num_loc_output,
@@ -760,7 +761,12 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
760761
stride=stride)
761762

762763
mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1])
763-
mbox_locs.append(mbox_loc)
764+
new_shape = [
765+
mbox_loc.shape[0],
766+
mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] / 4, 4
767+
]
768+
mbox_loc_flatten = ops.reshape(mbox_loc, shape=new_shape)
769+
mbox_locs.append(mbox_loc_flatten)
764770

765771
# get conf_loc
766772
num_conf_output = num_boxes * num_classes
@@ -771,11 +777,18 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
771777
padding=pad,
772778
stride=stride)
773779
conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
774-
mbox_confs.append(conf_loc)
780+
new_shape = [
781+
conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] *
782+
conf_loc.shape[3] / num_classes, num_classes
783+
]
784+
conf_loc_flatten = ops.reshape(conf_loc, shape=new_shape)
785+
mbox_confs.append(conf_loc_flatten)
775786

776787
if len(box_results) == 1:
777788
box = box_results[0]
778789
var = var_results[0]
790+
mbox_locs_concat = mbox_locs[0]
791+
mbox_confs_concat = mbox_confs[0]
779792
else:
780793
reshaped_boxes = []
781794
reshaped_vars = []
@@ -785,5 +798,7 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
785798

786799
box = tensor.concat(reshaped_boxes)
787800
var = tensor.concat(reshaped_vars)
801+
mbox_locs_concat = tensor.concat(mbox_locs, axis=1)
802+
mbox_confs_concat = tensor.concat(mbox_confs, axis=1)
788803

789-
return mbox_locs, mbox_confs, box, var
804+
return mbox_locs_concat, mbox_confs_concat, box, var

python/paddle/fluid/tests/test_detection.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ def test_detection_output(self):
3535
dtype='float32')
3636
loc = layers.data(
3737
name='target_box',
38-
shape=[20, 4],
38+
shape=[2, 10, 4],
3939
append_batch_size=False,
4040
dtype='float32')
4141
scores = layers.data(
4242
name='scores',
43-
shape=[2, 20, 10],
43+
shape=[2, 10, 20],
4444
append_batch_size=False,
4545
dtype='float32')
4646
out = layers.detection_output(
@@ -117,9 +117,7 @@ def test_multi_box_head(self):
117117
assert len(box.shape) == 2
118118
assert box.shape == var.shape
119119
assert box.shape[1] == 4
120-
121-
for loc, conf in zip(mbox_locs, mbox_confs):
122-
assert loc.shape[1:3] == conf.shape[1:3]
120+
assert mbox_locs.shape[1] == mbox_confs.shape[1]
123121

124122
def multi_box_head_output(self, data_shape):
125123
images = fluid.layers.data(

python/paddle/fluid/tests/unittests/test_box_coder_op.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
5151
prior_box_var[:,:,3]
5252

5353
elif (code_type == "DecodeCenterSize"):
54-
target_box = target_box.reshape(target_box.shape[0], 1,
55-
target_box.shape[1])
5654
target_box_x = prior_box_var[:,:,0] * target_box[:,:,0] * \
5755
prior_box_width + prior_box_x
5856
target_box_y = prior_box_var[:,:,1] * target_box[:,:,1] * \
@@ -61,6 +59,7 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
6159
prior_box_width
6260
target_box_height = np.exp(prior_box_var[:,:,3] * target_box[:,:,3]) * \
6361
prior_box_height
62+
6463
output_box[:, :, 0] = target_box_x - target_box_width / 2
6564
output_box[:, :, 1] = target_box_y - target_box_height / 2
6665
output_box[:, :, 2] = target_box_x + target_box_width / 2
@@ -72,8 +71,14 @@ def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type):
7271
m = prior_box.shape[0]
7372
output_box = np.zeros((n, m, 4), dtype=np.float32)
7473
for i in range(len(lod) - 1):
75-
box_coder(target_box[lod[i]:lod[i + 1], :], prior_box, prior_box_var,
76-
output_box[lod[i]:lod[i + 1], :, :], code_type)
74+
if (code_type == "EncodeCenterSize"):
75+
box_coder(target_box[lod[i]:lod[i + 1], :], prior_box,
76+
prior_box_var, output_box[lod[i]:lod[i + 1], :, :],
77+
code_type)
78+
elif (code_type == "DecodeCenterSize"):
79+
box_coder(target_box[lod[i]:lod[i + 1], :, :], prior_box,
80+
prior_box_var, output_box[lod[i]:lod[i + 1], :, :],
81+
code_type)
7782
return output_box
7883

7984

@@ -83,10 +88,10 @@ def test_check_output(self):
8388

8489
def setUp(self):
8590
self.op_type = "box_coder"
86-
lod = [[0, 20]]
91+
lod = [[0, 1, 2, 3, 4, 5]]
8792
prior_box = np.random.random((10, 4)).astype('float32')
8893
prior_box_var = np.random.random((10, 4)).astype('float32')
89-
target_box = np.random.random((20, 4)).astype('float32')
94+
target_box = np.random.random((5, 10, 4)).astype('float32')
9095
code_type = "DecodeCenterSize"
9196
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
9297
lod[0], code_type)

0 commit comments

Comments
 (0)