Skip to content

Commit 18c9505

Browse files
xizihadoop-basecv
authored andcommitted
add postprocessing_masks gpu version (open-mmlab#276)
* add postprocessing_masks gpu version * default device cpu * pre-commit fix Co-authored-by: hadoop-basecv <[email protected]>
1 parent 2f3de01 commit 18c9505

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

mmdeploy/codebase/mmdet/deploy/object_detection_model.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(self, backend: Backend, backend_files: Sequence[str],
6060
super().__init__(deploy_cfg=deploy_cfg)
6161
self.CLASSES = class_names
6262
self.deploy_cfg = deploy_cfg
63+
self.device = device
6364
self._init_wrapper(
6465
backend=backend, backend_files=backend_files, device=device)
6566

@@ -114,6 +115,7 @@ def postprocessing_masks(det_bboxes: np.ndarray,
114115
det_masks: np.ndarray,
115116
img_w: int,
116117
img_h: int,
118+
device: str = 'cpu',
117119
mask_thr_binary: float = 0.5) -> np.ndarray:
118120
"""Additional processing of masks. Resizes masks from [num_det, 28, 28]
119121
to [num_det, img_w, img_h]. Analog of the 'mmdeploy.codebase.mmdet.
@@ -138,17 +140,25 @@ def postprocessing_masks(det_bboxes: np.ndarray,
138140
return np.zeros((0, img_h, img_w))
139141

140142
if isinstance(masks, np.ndarray):
141-
masks = torch.tensor(masks)
142-
bboxes = torch.tensor(bboxes)
143+
masks = torch.tensor(masks, device=torch.device(device))
144+
bboxes = torch.tensor(bboxes, device=torch.device(device))
143145

144146
result_masks = []
145147
for bbox, mask in zip(bboxes, masks):
146148

147149
x0_int, y0_int = 0, 0
148150
x1_int, y1_int = img_w, img_h
149151

150-
img_y = torch.arange(y0_int, y1_int, dtype=torch.float32) + 0.5
151-
img_x = torch.arange(x0_int, x1_int, dtype=torch.float32) + 0.5
152+
img_y = torch.arange(
153+
y0_int,
154+
y1_int,
155+
dtype=torch.float32,
156+
device=torch.device(device)) + 0.5
157+
img_x = torch.arange(
158+
x0_int,
159+
x1_int,
160+
dtype=torch.float32,
161+
device=torch.device(device)) + 0.5
152162
x0, y0, x1, y1 = bbox
153163

154164
img_y = (img_y - y0) / (y1 - y0) * 2 - 1
@@ -169,10 +179,8 @@ def postprocessing_masks(det_bboxes: np.ndarray,
169179
grid[None, :, :, :],
170180
align_corners=False)
171181

172-
mask = img_masks
173-
mask = (mask >= mask_thr_binary).to(dtype=torch.bool)
174-
result_masks.append(mask.numpy())
175-
result_masks = np.concatenate(result_masks, axis=1)
182+
result_masks.append(img_masks)
183+
result_masks = torch.cat(result_masks, 1)
176184
return result_masks.squeeze(0)
177185

178186
def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict],
@@ -206,6 +214,8 @@ def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict],
206214
if isinstance(scale_factor, (list, tuple, np.ndarray)):
207215
assert len(scale_factor) == 4
208216
scale_factor = np.array(scale_factor)[None, :] # [1,4]
217+
scale_factor = torch.from_numpy(scale_factor).to(
218+
device=torch.device(self.device))
209219
dets[:, :4] /= scale_factor
210220

211221
if 'border' in img_metas[i]:
@@ -216,7 +226,7 @@ def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict],
216226
y_off = img_metas[i]['border'][0]
217227
dets[:, [0, 2]] -= x_off
218228
dets[:, [1, 3]] -= y_off
219-
dets[:, :4] *= (dets[:, :4] > 0).astype(dets.dtype)
229+
dets[:, :4] *= (dets[:, :4] > 0)
220230

221231
dets_results = bbox2result(dets, labels, len(self.CLASSES))
222232

@@ -234,16 +244,14 @@ def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict],
234244
'export_postprocess_mask', True)
235245
if not export_postprocess_mask:
236246
masks = End2EndModel.postprocessing_masks(
237-
dets[:, :4], masks, ori_w, ori_h)
247+
dets[:, :4], masks, ori_w, ori_h, self.device)
238248
else:
239249
masks = masks[:, :img_h, :img_w]
240250
# avoid to resize masks with zero dim
241251
if rescale and masks.shape[0] != 0:
242-
masks = masks.astype(np.float32)
243-
masks = torch.from_numpy(masks)
244252
masks = torch.nn.functional.interpolate(
245253
masks.unsqueeze(0), size=(ori_h, ori_w))
246-
masks = masks.squeeze(0).detach().numpy()
254+
masks = masks.squeeze(0)
247255
if masks.dtype != bool:
248256
masks = masks >= 0.5
249257
segms_results = [[] for _ in range(len(self.CLASSES))]
@@ -267,7 +275,6 @@ def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \
267275
"""
268276
outputs = self.wrapper({self.input_name: imgs})
269277
outputs = self.wrapper.output_to_list(outputs)
270-
outputs = [out.detach().cpu().numpy() for out in outputs]
271278
return outputs
272279

273280
def show_result(self,

0 commit comments

Comments
 (0)