@@ -60,6 +60,7 @@ def __init__(self, backend: Backend, backend_files: Sequence[str],
6060 super ().__init__ (deploy_cfg = deploy_cfg )
6161 self .CLASSES = class_names
6262 self .deploy_cfg = deploy_cfg
63+ self .device = device
6364 self ._init_wrapper (
6465 backend = backend , backend_files = backend_files , device = device )
6566
@@ -114,6 +115,7 @@ def postprocessing_masks(det_bboxes: np.ndarray,
114115 det_masks : np .ndarray ,
115116 img_w : int ,
116117 img_h : int ,
118+ device : str = 'cpu' ,
117119 mask_thr_binary : float = 0.5 ) -> np .ndarray :
118120 """Additional processing of masks. Resizes masks from [num_det, 28, 28]
119121 to [num_det, img_w, img_h]. Analog of the 'mmdeploy.codebase.mmdet.
@@ -138,17 +140,25 @@ def postprocessing_masks(det_bboxes: np.ndarray,
138140 return np .zeros ((0 , img_h , img_w ))
139141
140142 if isinstance (masks , np .ndarray ):
141- masks = torch .tensor (masks )
142- bboxes = torch .tensor (bboxes )
143+ masks = torch .tensor (masks , device = torch . device ( device ) )
144+ bboxes = torch .tensor (bboxes , device = torch . device ( device ) )
143145
144146 result_masks = []
145147 for bbox , mask in zip (bboxes , masks ):
146148
147149 x0_int , y0_int = 0 , 0
148150 x1_int , y1_int = img_w , img_h
149151
150- img_y = torch .arange (y0_int , y1_int , dtype = torch .float32 ) + 0.5
151- img_x = torch .arange (x0_int , x1_int , dtype = torch .float32 ) + 0.5
152+ img_y = torch .arange (
153+ y0_int ,
154+ y1_int ,
155+ dtype = torch .float32 ,
156+ device = torch .device (device )) + 0.5
157+ img_x = torch .arange (
158+ x0_int ,
159+ x1_int ,
160+ dtype = torch .float32 ,
161+ device = torch .device (device )) + 0.5
152162 x0 , y0 , x1 , y1 = bbox
153163
154164 img_y = (img_y - y0 ) / (y1 - y0 ) * 2 - 1
@@ -169,10 +179,8 @@ def postprocessing_masks(det_bboxes: np.ndarray,
169179 grid [None , :, :, :],
170180 align_corners = False )
171181
172- mask = img_masks
173- mask = (mask >= mask_thr_binary ).to (dtype = torch .bool )
174- result_masks .append (mask .numpy ())
175- result_masks = np .concatenate (result_masks , axis = 1 )
182+ result_masks .append (img_masks )
183+ result_masks = torch .cat (result_masks , 1 )
176184 return result_masks .squeeze (0 )
177185
178186 def forward (self , img : Sequence [torch .Tensor ], img_metas : Sequence [dict ],
@@ -206,6 +214,8 @@ def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict],
206214 if isinstance (scale_factor , (list , tuple , np .ndarray )):
207215 assert len (scale_factor ) == 4
208216 scale_factor = np .array (scale_factor )[None , :] # [1,4]
217+ scale_factor = torch .from_numpy (scale_factor ).to (
218+ device = torch .device (self .device ))
209219 dets [:, :4 ] /= scale_factor
210220
211221 if 'border' in img_metas [i ]:
@@ -216,7 +226,7 @@ def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict],
216226 y_off = img_metas [i ]['border' ][0 ]
217227 dets [:, [0 , 2 ]] -= x_off
218228 dets [:, [1 , 3 ]] -= y_off
219- dets [:, :4 ] *= (dets [:, :4 ] > 0 ). astype ( dets . dtype )
229+ dets [:, :4 ] *= (dets [:, :4 ] > 0 )
220230
221231 dets_results = bbox2result (dets , labels , len (self .CLASSES ))
222232
@@ -234,16 +244,14 @@ def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[dict],
234244 'export_postprocess_mask' , True )
235245 if not export_postprocess_mask :
236246 masks = End2EndModel .postprocessing_masks (
237- dets [:, :4 ], masks , ori_w , ori_h )
247+ dets [:, :4 ], masks , ori_w , ori_h , self . device )
238248 else :
239249 masks = masks [:, :img_h , :img_w ]
240250 # avoid to resize masks with zero dim
241251 if rescale and masks .shape [0 ] != 0 :
242- masks = masks .astype (np .float32 )
243- masks = torch .from_numpy (masks )
244252 masks = torch .nn .functional .interpolate (
245253 masks .unsqueeze (0 ), size = (ori_h , ori_w ))
246- masks = masks .squeeze (0 ). detach (). numpy ()
254+ masks = masks .squeeze (0 )
247255 if masks .dtype != bool :
248256 masks = masks >= 0.5
249257 segms_results = [[] for _ in range (len (self .CLASSES ))]
@@ -267,7 +275,6 @@ def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \
267275 """
268276 outputs = self .wrapper ({self .input_name : imgs })
269277 outputs = self .wrapper .output_to_list (outputs )
270- outputs = [out .detach ().cpu ().numpy () for out in outputs ]
271278 return outputs
272279
273280 def show_result (self ,
0 commit comments