Skip to content

Commit 987d48c

Browse files
q.yaoAllentDan
andauthored
[Enhancement] Update pad logic in detection heads (#168)
* pad with register * fix lint Co-authored-by: AllentDan <[email protected]>
1 parent 9553b8c commit 987d48c

File tree

7 files changed

+104
-36
lines changed

7 files changed

+104
-36
lines changed
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .core import * # noqa: F401,F403
33
from .deploy import (MMDetection, ObjectDetection, clip_bboxes,
4-
get_post_processing_params, pad_with_value)
4+
get_post_processing_params, pad_with_value,
5+
pad_with_value_if_necessary)
56
from .models import * # noqa: F401,F403
67

78
__all__ = [
89
'get_post_processing_params', 'clip_bboxes', 'pad_with_value',
9-
'MMDetection', 'ObjectDetection'
10+
'pad_with_value_if_necessary', 'MMDetection', 'ObjectDetection'
1011
]
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .mmdetection import MMDetection
33
from .object_detection import ObjectDetection
4-
from .utils import clip_bboxes, get_post_processing_params, pad_with_value
4+
from .utils import (clip_bboxes, get_post_processing_params, pad_with_value,
5+
pad_with_value_if_necessary)
56

67
__all__ = [
78
'get_post_processing_params', 'clip_bboxes', 'pad_with_value',
8-
'MMDetection', 'ObjectDetection'
9+
'pad_with_value_if_necessary', 'MMDetection', 'ObjectDetection'
910
]

mmdeploy/codebase/mmdet/deploy/utils.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from mmdeploy.core import FUNCTION_REWRITER
99
from mmdeploy.core.rewriters.rewriter_utils import LibVersionChecker
10-
from mmdeploy.utils import load_config
10+
from mmdeploy.utils import Backend, load_config
1111

1212

1313
def get_post_processing_params(deploy_cfg: Union[str, mmcv.Config]):
@@ -127,3 +127,62 @@ def pad_with_value(x: Tensor,
127127
x_pad = x_pad.repeat(*repeat_size)
128128
x = torch.cat([x, x_pad], dim=pad_dim)
129129
return x
130+
131+
132+
def pad_with_value_if_necessary(x: Tensor,
133+
pad_dim: int,
134+
pad_size: int,
135+
pad_value: Optional[Any] = None):
136+
"""Pad a tensor with a value along some dim if necessary.
137+
138+
Args:
139+
x (Tensor): Input tensor.
140+
pad_dim (int): Along which dim to pad.
141+
pad_size (int): To which size to pad.
142+
pad_value (Any): Filled value for padding. Defaults to `None`.
143+
144+
Returns:
145+
Tensor: Padded tensor.
146+
"""
147+
return __pad_with_value_if_necessary(
148+
x, pad_dim, pad_size=pad_size, pad_value=pad_value)
149+
150+
151+
def __pad_with_value_if_necessary(x: Tensor,
152+
pad_dim: int,
153+
pad_size: int,
154+
pad_value: Optional[Any] = None):
155+
"""Pad a tensor with a value along some dim, do nothing on default.
156+
157+
Args:
158+
x (Tensor): Input tensor.
159+
pad_dim (int): Along which dim to pad.
160+
pad_size (int): To which size to pad.
161+
pad_value (Any): Filled value for padding. Defaults to `None`.
162+
163+
Returns:
164+
Tensor: Padded tensor.
165+
"""
166+
return x
167+
168+
169+
@FUNCTION_REWRITER.register_rewriter(
170+
'mmdeploy.codebase.mmdet.deploy.utils.__pad_with_value_if_necessary',
171+
backend=Backend.TENSORRT.value)
172+
def __pad_with_value_if_necessary__tensorrt(ctx,
173+
x: Tensor,
174+
pad_dim: int,
175+
pad_size: int,
176+
pad_value: Optional[Any] = None):
177+
"""Pad a tensor with a value along some dim.
178+
179+
Args:
180+
x (Tensor): Input tensor.
181+
pad_dim (int): Along which dim to pad.
182+
pad_size (int): To which size to pad.
183+
pad_value (Any): Filled value for padding. Defaults to `None`.
184+
185+
Returns:
186+
Tensor: Padded tensor.
187+
"""
188+
return pad_with_value(x, pad_dim, pad_size=pad_size, pad_value=pad_value)

mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from mmdet.core.bbox.transforms import distance2bbox
66

77
from mmdeploy.codebase.mmdet import (get_post_processing_params,
8-
multiclass_nms, pad_with_value)
8+
multiclass_nms,
9+
pad_with_value_if_necessary)
910
from mmdeploy.codebase.mmdet.core.ops import ncnn_detection_output_forward
1011
from mmdeploy.core import FUNCTION_REWRITER
11-
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
12+
from mmdeploy.utils import Backend, is_dynamic_shape
1213

1314

1415
@FUNCTION_REWRITER.register_rewriter(
@@ -60,7 +61,6 @@ def base_dense_head__get_bbox(ctx,
6061
"""
6162
deploy_cfg = ctx.cfg
6263
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
63-
backend = get_backend(deploy_cfg)
6464
num_levels = len(cls_scores)
6565

6666
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
@@ -98,10 +98,8 @@ def base_dense_head__get_bbox(ctx,
9898
self.cls_out_channels)
9999
if self.use_sigmoid_cls:
100100
scores = scores.sigmoid()
101-
nms_pre_score = scores
102101
else:
103102
scores = scores.softmax(-1)
104-
nms_pre_score = scores
105103
if with_score_factors:
106104
score_factors = score_factors.permute(0, 2, 3,
107105
1).reshape(batch_size,
@@ -112,16 +110,16 @@ def base_dense_head__get_bbox(ctx,
112110
priors = priors.data
113111
priors = priors.expand(batch_size, -1, priors.size(-1))
114112
if pre_topk > 0:
113+
priors = pad_with_value_if_necessary(priors, 1, pre_topk)
114+
bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
115+
scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)
116+
if with_score_factors:
117+
score_factors = pad_with_value_if_necessary(
118+
score_factors, 1, pre_topk, 0.)
119+
120+
nms_pre_score = scores
115121
if with_score_factors:
116122
nms_pre_score = nms_pre_score * score_factors
117-
if backend == Backend.TENSORRT:
118-
priors = pad_with_value(priors, 1, pre_topk)
119-
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
120-
scores = pad_with_value(scores, 1, pre_topk, 0.)
121-
nms_pre_score = pad_with_value(nms_pre_score, 1, pre_topk, 0.)
122-
if with_score_factors:
123-
score_factors = pad_with_value(score_factors, 1, pre_topk,
124-
0.)
125123

126124
# Get maximum scores for foreground classes.
127125
if self.use_sigmoid_cls:
@@ -180,7 +178,7 @@ def base_dense_head__get_bbox(ctx,
180178
@FUNCTION_REWRITER.register_rewriter(
181179
func_name='mmdet.models.dense_heads.base_dense_head.BaseDenseHead'
182180
'.get_bboxes',
183-
backend='ncnn')
181+
backend=Backend.NCNN.value)
184182
def base_dense_head__get_bboxes__ncnn(ctx,
185183
self,
186184
cls_scores,

mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import torch
33

44
from mmdeploy.codebase.mmdet import (get_post_processing_params,
5-
multiclass_nms, pad_with_value)
5+
multiclass_nms,
6+
pad_with_value_if_necessary)
67
from mmdeploy.core import FUNCTION_REWRITER
7-
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
8+
from mmdeploy.utils import Backend, is_dynamic_shape
89

910

1011
@FUNCTION_REWRITER.register_rewriter(
@@ -95,13 +96,11 @@ def rpn_head__get_bboxes(ctx,
9596

9697
anchors = anchors.expand_as(bbox_pred)
9798

98-
backend = get_backend(deploy_cfg)
9999
# topk in tensorrt does not support shape<k
100100
# concate zero to enable topk,
101-
if backend == Backend.TENSORRT:
102-
scores = pad_with_value(scores, 1, pre_topk, 0.)
103-
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
104-
anchors = pad_with_value(anchors, 1, pre_topk)
101+
scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)
102+
bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
103+
anchors = pad_with_value_if_necessary(anchors, 1, pre_topk)
105104

106105
if pre_topk > 0:
107106
_, topk_inds = scores.squeeze(2).topk(pre_topk)
@@ -145,7 +144,7 @@ def rpn_head__get_bboxes(ctx,
145144

146145

147146
@FUNCTION_REWRITER.register_rewriter(
148-
'mmdet.models.dense_heads.RPNHead.get_bboxes', backend='ncnn')
147+
'mmdet.models.dense_heads.RPNHead.get_bboxes', backend=Backend.NCNN.value)
149148
def rpn_head__get_bboxes__ncnn(ctx,
150149
self,
151150
cls_scores,

mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import torch
44

55
from mmdeploy.codebase.mmdet import (get_post_processing_params,
6-
multiclass_nms, pad_with_value)
6+
multiclass_nms,
7+
pad_with_value_if_necessary)
78
from mmdeploy.core import FUNCTION_REWRITER
8-
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
9+
from mmdeploy.utils import Backend, is_dynamic_shape
910

1011

1112
@FUNCTION_REWRITER.register_rewriter(
@@ -90,13 +91,11 @@ def yolov3_head__get_bboxes(ctx,
9091
conf_pred = torch.sigmoid(pred_map[..., 4])
9192
cls_pred = torch.sigmoid(pred_map[..., 5:]).view(
9293
batch_size, -1, self.num_classes) # Cls pred one-hot.
93-
backend = get_backend(ctx.cfg)
9494
# topk in tensorrt does not support shape<k
9595
# concate zero to enable topk,
96-
if backend == Backend.TENSORRT:
97-
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
98-
conf_pred = pad_with_value(conf_pred, 1, pre_topk, 0.)
99-
cls_pred = pad_with_value(cls_pred, 1, pre_topk, 0.)
96+
bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
97+
conf_pred = pad_with_value_if_necessary(conf_pred, 1, pre_topk, 0.)
98+
cls_pred = pad_with_value_if_necessary(cls_pred, 1, pre_topk, 0.)
10099

101100
if pre_topk > 0:
102101
_, topk_inds = conf_pred.topk(pre_topk)
@@ -161,7 +160,8 @@ def yolov3_head__get_bboxes(ctx,
161160

162161

163162
@FUNCTION_REWRITER.register_rewriter(
164-
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes', backend='ncnn')
163+
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes',
164+
backend=Backend.NCNN.value)
165165
def yolov3_head__get_bboxes__ncnn(ctx,
166166
self,
167167
pred_maps,

tests/test_codebase/test_mmdet/test_mmdet_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from mmdeploy.codebase import import_codebase
77
from mmdeploy.codebase.mmdet import (clip_bboxes, get_post_processing_params,
8-
pad_with_value)
8+
pad_with_value,
9+
pad_with_value_if_necessary)
910
from mmdeploy.utils import Codebase
1011

1112
import_codebase(Codebase.MMDET)
@@ -29,6 +30,15 @@ def test_pad_with_value():
2930
assert np.allclose(padded_x.sum(), x.sum(), rtol=1e-03, atol=1e-05)
3031

3132

33+
def test_pad_with_value_if_necessary():
34+
x = torch.rand(3, 2)
35+
padded_x = pad_with_value_if_necessary(
36+
x, pad_dim=1, pad_size=4, pad_value=0)
37+
assert np.allclose(
38+
padded_x.shape, torch.Size([3, 2]), rtol=1e-03, atol=1e-05)
39+
assert np.allclose(padded_x.sum(), x.sum(), rtol=1e-03, atol=1e-05)
40+
41+
3242
config_with_mmdet_params = mmcv.Config(
3343
dict(
3444
codebase_config=dict(

0 commit comments

Comments
 (0)