1515import os
1616import yaml
1717import glob
18+ import json
19+ from pathlib import Path
1820from functools import reduce
1921
2022import cv2
@@ -233,7 +235,8 @@ def predict_image(self,
233235 image_list ,
234236 run_benchmark = False ,
235237 repeats = 1 ,
236- visual = True ):
238+ visual = True ,
239+ save_file = None ):
237240 batch_loop_cnt = math .ceil (float (len (image_list )) / self .batch_size )
238241 results = []
239242 for i in range (batch_loop_cnt ):
@@ -293,6 +296,10 @@ def predict_image(self,
293296 if visual :
294297 print ('Test iter {}' .format (i ))
295298
299+ if save_file is not None :
300+ Path (self .output_dir ).mkdir (exist_ok = True )
301+ self .format_coco_results (image_list , results , save_file = save_file )
302+
296303 results = self .merge_batch_result (results )
297304 return results
298305
@@ -313,7 +320,7 @@ def predict_video(self, video_file, camera_id):
313320 if not os .path .exists (self .output_dir ):
314321 os .makedirs (self .output_dir )
315322 out_path = os .path .join (self .output_dir , video_out_name )
316- fourcc = cv2 .VideoWriter_fourcc (* 'mp4v' )
323+ fourcc = cv2 .VideoWriter_fourcc (* 'mp4v' )
317324 writer = cv2 .VideoWriter (out_path , fourcc , fps , (width , height ))
318325 index = 1
319326 while (1 ):
@@ -337,6 +344,68 @@ def predict_video(self, video_file, camera_id):
337344 break
338345 writer .release ()
339346
347+ @staticmethod
348+ def format_coco_results (image_list , results , save_file = None ):
349+ coco_results = []
350+ image_id = 0
351+
352+ for result in results :
353+ start_idx = 0
354+ for box_num in result ['boxes_num' ]:
355+ idx_slice = slice (start_idx , start_idx + box_num )
356+ start_idx += box_num
357+
358+ image_file = image_list [image_id ]
359+ image_id += 1
360+
361+ if 'boxes' in result :
362+ boxes = result ['boxes' ][idx_slice , :]
363+ per_result = [
364+ {
365+ 'image_file' : image_file ,
366+ 'bbox' :
367+ [box [2 ], box [3 ], box [4 ] - box [2 ],
368+ box [5 ] - box [3 ]], # xyxy -> xywh
369+ 'score' : box [1 ],
370+ 'category_id' : int (box [0 ]),
371+ } for k , box in enumerate (boxes .tolist ())
372+ ]
373+
374+ elif 'segm' in result :
375+ import pycocotools .mask as mask_util
376+
377+ scores = result ['score' ][idx_slice ].tolist ()
378+ category_ids = result ['label' ][idx_slice ].tolist ()
379+ segms = result ['segm' ][idx_slice , :]
380+ rles = [
381+ mask_util .encode (
382+ np .array (
383+ mask [:, :, np .newaxis ],
384+ dtype = np .uint8 ,
385+ order = 'F' ))[0 ] for mask in segms
386+ ]
387+ for rle in rles :
388+ rle ['counts' ] = rle ['counts' ].decode ('utf-8' )
389+
390+ per_result = [{
391+ 'image_file' : image_file ,
392+ 'segmentation' : rle ,
393+ 'score' : scores [k ],
394+ 'category_id' : category_ids [k ],
395+ } for k , rle in enumerate (rles )]
396+
397+ else :
398+ raise RuntimeError ('' )
399+
400+ # per_result = [item for item in per_result if item['score'] > threshold]
401+ coco_results .extend (per_result )
402+
403+ if save_file :
404+ with open (os .path .join (save_file ), 'w' ) as f :
405+ json .dump (coco_results , f )
406+
407+ return coco_results
408+
340409
341410class DetectorSOLOv2 (Detector ):
342411 """
@@ -807,7 +876,10 @@ def main():
807876 if FLAGS .image_dir is None and FLAGS .image_file is not None :
808877 assert FLAGS .batch_size == 1 , "batch_size should be 1, when image_file is not None"
809878 img_list = get_test_images (FLAGS .image_dir , FLAGS .image_file )
810- detector .predict_image (img_list , FLAGS .run_benchmark , repeats = 100 )
879+ save_file = os .path .join (FLAGS .output_dir ,
880+ 'results.json' ) if FLAGS .save_results else None
881+ detector .predict_image (
882+ img_list , FLAGS .run_benchmark , repeats = 100 , save_file = save_file )
811883 if not FLAGS .run_benchmark :
812884 detector .det_times .info (average = True )
813885 else :
0 commit comments