-
Notifications
You must be signed in to change notification settings - Fork 689
add postprocessing_masks gpu version #276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,6 +60,7 @@ def __init__(self, backend: Backend, backend_files: Sequence[str], | |
| super().__init__(deploy_cfg=deploy_cfg) | ||
| self.CLASSES = class_names | ||
| self.deploy_cfg = deploy_cfg | ||
| self.device = device | ||
| self._init_wrapper( | ||
| backend=backend, backend_files=backend_files, device=device) | ||
|
|
||
|
|
@@ -114,6 +115,7 @@ def postprocessing_masks(det_bboxes: np.ndarray, | |
| det_masks: np.ndarray, | ||
| img_w: int, | ||
| img_h: int, | ||
| device: str = 'cpu', | ||
| mask_thr_binary: float = 0.5) -> np.ndarray: | ||
| """Additional processing of masks. Resizes masks from [num_det, 28, 28] | ||
| to [num_det, img_w, img_h]. Analog of the 'mmdeploy.codebase.mmdet. | ||
|
|
@@ -138,17 +140,25 @@ def postprocessing_masks(det_bboxes: np.ndarray, | |
| return np.zeros((0, img_h, img_w)) | ||
|
|
||
| if isinstance(masks, np.ndarray): | ||
| masks = torch.tensor(masks) | ||
| bboxes = torch.tensor(bboxes) | ||
| masks = torch.tensor(masks, device=torch.device(device)) | ||
| bboxes = torch.tensor(bboxes, device=torch.device(device)) | ||
|
|
||
| result_masks = [] | ||
| for bbox, mask in zip(bboxes, masks): | ||
|
|
||
| x0_int, y0_int = 0, 0 | ||
| x1_int, y1_int = img_w, img_h | ||
|
|
||
| img_y = torch.arange(y0_int, y1_int, dtype=torch.float32) + 0.5 | ||
| img_x = torch.arange(x0_int, x1_int, dtype=torch.float32) + 0.5 | ||
| img_y = torch.arange( | ||
| y0_int, | ||
| y1_int, | ||
| dtype=torch.float32, | ||
| device=torch.device(device)) + 0.5 | ||
| img_x = torch.arange( | ||
| x0_int, | ||
| x1_int, | ||
| dtype=torch.float32, | ||
| device=torch.device(device)) + 0.5 | ||
| x0, y0, x1, y1 = bbox | ||
|
|
||
| img_y = (img_y - y0) / (y1 - y0) * 2 - 1 | ||
|
|
@@ -169,10 +179,8 @@ def postprocessing_masks(det_bboxes: np.ndarray, | |
| grid[None, :, :, :], | ||
| align_corners=False) | ||
|
|
||
| mask = img_masks | ||
| mask = (mask >= mask_thr_binary).to(dtype=torch.bool) | ||
| result_masks.append(mask.numpy()) | ||
| result_masks = np.concatenate(result_masks, axis=1) | ||
| result_masks.append(img_masks) | ||
| result_masks = torch.cat(result_masks, 1) | ||
| return result_masks.squeeze(0) | ||
|
|
||
| def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict], | ||
|
|
@@ -206,6 +214,8 @@ def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict], | |
| if isinstance(scale_factor, (list, tuple, np.ndarray)): | ||
| assert len(scale_factor) == 4 | ||
| scale_factor = np.array(scale_factor)[None, :] # [1,4] | ||
| scale_factor = torch.from_numpy(scale_factor).to( | ||
| device=torch.device(self.device)) | ||
| dets[:, :4] /= scale_factor | ||
|
|
||
| if 'border' in img_metas[i]: | ||
|
|
@@ -216,7 +226,7 @@ def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict], | |
| y_off = img_metas[i]['border'][0] | ||
| dets[:, [0, 2]] -= x_off | ||
| dets[:, [1, 3]] -= y_off | ||
| dets[:, :4] *= (dets[:, :4] > 0).astype(dets.dtype) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you remove
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dets的数据类型是torch.tensor,dtype是float32,不需要astype
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. torch.tensor dtype float32与bool 相乘会自动转换为float32
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. A better way is using import torch
from torch.profiler import profile, ProfilerActivity
def func_test(func,
data,
threshold,
num_test=10,
activate=ProfilerActivity.CPU):
# warmup
for _ in range(num_test):
out = func(data, threshold)
with profile(activities=[activate], record_shapes=True) as prof:
for _ in range(num_test):
out = func(data, threshold)
print(prof.key_averages().table(
sort_by="cpu_time_total", top_level_events_only=True, row_limit=30))
return out
def main():
data = torch.rand(1, 3, 224, 224)
threshold = 0.5
activate = ProfilerActivity.CPU
out0 = func_test(
lambda data, threshold: data * (data > threshold),
data,
threshold,
activate=activate)
out1 = func_test(
lambda data, threshold: torch.nn.functional.threshold(
data, threshold, 0),
data,
threshold,
activate=activate)
torch.testing.assert_allclose(out0, out1)
if __name__ == '__main__':
main()
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 牛皮。 |
||
| dets[:, :4] *= (dets[:, :4] > 0) | ||
|
|
||
| dets_results = bbox2result(dets, labels, len(self.CLASSES)) | ||
|
|
||
|
|
@@ -234,16 +244,14 @@ def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict], | |
| 'export_postprocess_mask', True) | ||
| if not export_postprocess_mask: | ||
| masks = End2EndModel.postprocessing_masks( | ||
| dets[:, :4], masks, ori_w, ori_h) | ||
| dets[:, :4], masks, ori_w, ori_h, self.device) | ||
| else: | ||
| masks = masks[:, :img_h, :img_w] | ||
| # avoid to resize masks with zero dim | ||
| if rescale and masks.shape[0] != 0: | ||
| masks = masks.astype(np.float32) | ||
| masks = torch.from_numpy(masks) | ||
| masks = torch.nn.functional.interpolate( | ||
| masks.unsqueeze(0), size=(ori_h, ori_w)) | ||
| masks = masks.squeeze(0).detach().numpy() | ||
| masks = masks.squeeze(0) | ||
| if masks.dtype != bool: | ||
| masks = masks >= 0.5 | ||
| segms_results = [[] for _ in range(len(self.CLASSES))] | ||
|
|
@@ -267,7 +275,6 @@ def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \ | |
| """ | ||
| outputs = self.wrapper({self.input_name: imgs}) | ||
| outputs = self.wrapper.output_to_list(outputs) | ||
| outputs = [out.detach().cpu().numpy() for out in outputs] | ||
| return outputs | ||
|
|
||
| def show_result(self, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.