diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index b368d10972..51f2b3cc80 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -60,6 +60,7 @@ def __init__(self, backend: Backend, backend_files: Sequence[str], super().__init__(deploy_cfg=deploy_cfg) self.CLASSES = class_names self.deploy_cfg = deploy_cfg + self.device = device self._init_wrapper( backend=backend, backend_files=backend_files, device=device) @@ -114,6 +115,7 @@ def postprocessing_masks(det_bboxes: np.ndarray, det_masks: np.ndarray, img_w: int, img_h: int, + device: str = 'cpu', mask_thr_binary: float = 0.5) -> np.ndarray: """Additional processing of masks. Resizes masks from [num_det, 28, 28] to [num_det, img_w, img_h]. Analog of the 'mmdeploy.codebase.mmdet. @@ -138,8 +140,8 @@ def postprocessing_masks(det_bboxes: np.ndarray, return np.zeros((0, img_h, img_w)) if isinstance(masks, np.ndarray): - masks = torch.tensor(masks) - bboxes = torch.tensor(bboxes) + masks = torch.tensor(masks, device=torch.device(device)) + bboxes = torch.tensor(bboxes, device=torch.device(device)) result_masks = [] for bbox, mask in zip(bboxes, masks): @@ -147,8 +149,16 @@ def postprocessing_masks(det_bboxes: np.ndarray, x0_int, y0_int = 0, 0 x1_int, y1_int = img_w, img_h - img_y = torch.arange(y0_int, y1_int, dtype=torch.float32) + 0.5 - img_x = torch.arange(x0_int, x1_int, dtype=torch.float32) + 0.5 + img_y = torch.arange( + y0_int, + y1_int, + dtype=torch.float32, + device=torch.device(device)) + 0.5 + img_x = torch.arange( + x0_int, + x1_int, + dtype=torch.float32, + device=torch.device(device)) + 0.5 x0, y0, x1, y1 = bbox img_y = (img_y - y0) / (y1 - y0) * 2 - 1 @@ -169,10 +179,8 @@ def postprocessing_masks(det_bboxes: np.ndarray, grid[None, :, :, :], align_corners=False) - mask = img_masks - mask = (mask >= mask_thr_binary).to(dtype=torch.bool) - result_masks.append(mask.numpy()) - result_masks = np.concatenate(result_masks, axis=1) + result_masks.append(img_masks) + result_masks = torch.cat(result_masks, 1) return result_masks.squeeze(0) def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict], @@ -206,6 +214,8 @@ def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict], if isinstance(scale_factor, (list, tuple, np.ndarray)): assert len(scale_factor) == 4 scale_factor = np.array(scale_factor)[None, :] # [1,4] + scale_factor = torch.from_numpy(scale_factor).to( + device=torch.device(self.device)) dets[:, :4] /= scale_factor if 'border' in img_metas[i]: @@ -216,7 +226,7 @@ def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict], y_off = img_metas[i]['border'][0] dets[:, [0, 2]] -= x_off dets[:, [1, 3]] -= y_off - dets[:, :4] *= (dets[:, :4] > 0).astype(dets.dtype) + dets[:, :4] *= (dets[:, :4] > 0) dets_results = bbox2result(dets, labels, len(self.CLASSES)) @@ -234,16 +244,14 @@ def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict], 'export_postprocess_mask', True) if not export_postprocess_mask: masks = End2EndModel.postprocessing_masks( - dets[:, :4], masks, ori_w, ori_h) + dets[:, :4], masks, ori_w, ori_h, self.device) else: masks = masks[:, :img_h, :img_w] # avoid to resize masks with zero dim if rescale and masks.shape[0] != 0: - masks = masks.astype(np.float32) - masks = torch.from_numpy(masks) masks = torch.nn.functional.interpolate( masks.unsqueeze(0), size=(ori_h, ori_w)) - masks = masks.squeeze(0).detach().numpy() + masks = masks.squeeze(0) if masks.dtype != bool: masks = masks >= 0.5 segms_results = [[] for _ in range(len(self.CLASSES))] @@ -267,7 +275,6 @@ def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \ """ outputs = self.wrapper({self.input_name: imgs}) outputs = self.wrapper.output_to_list(outputs) - outputs = [out.detach().cpu().numpy() for out in outputs] return outputs def show_result(self,