Skip to content

Commit 1ee341b

Browse files
authored
save detection results to file using coco format #5782 (#5787)
* save detection results to file using coco format
* update save docs
1 parent fa250ff commit 1ee341b

File tree

3 files changed

+83
-3
lines changed

3 files changed

+83
-3
lines changed

deploy/python/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ python deploy/python/mot_keypoint_unite_infer.py --mot_model_dir=output_inferenc
9191
| --enable_mkldnn | Option | CPU预测中是否开启MKLDNN加速,默认为False |
9292
| --cpu_threads | Option| 设置cpu线程数,默认为1 |
9393
| --trt_calib_mode | Option| TensorRT是否使用校准功能,默认为False。使用TensorRT的int8功能时,需设置为True,使用PaddleSlim量化后的模型时需要设置为False |
94+
| --save_results | Option| 是否在文件夹下将图片的预测结果以JSON的形式保存 |
95+
9496

9597
说明:
9698

deploy/python/infer.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import os
1616
import yaml
1717
import glob
18+
import json
19+
from pathlib import Path
1820
from functools import reduce
1921

2022
import cv2
@@ -233,7 +235,8 @@ def predict_image(self,
233235
image_list,
234236
run_benchmark=False,
235237
repeats=1,
236-
visual=True):
238+
visual=True,
239+
save_file=None):
237240
batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
238241
results = []
239242
for i in range(batch_loop_cnt):
@@ -293,6 +296,10 @@ def predict_image(self,
293296
if visual:
294297
print('Test iter {}'.format(i))
295298

299+
if save_file is not None:
300+
Path(self.output_dir).mkdir(exist_ok=True)
301+
self.format_coco_results(image_list, results, save_file=save_file)
302+
296303
results = self.merge_batch_result(results)
297304
return results
298305

@@ -313,7 +320,7 @@ def predict_video(self, video_file, camera_id):
313320
if not os.path.exists(self.output_dir):
314321
os.makedirs(self.output_dir)
315322
out_path = os.path.join(self.output_dir, video_out_name)
316-
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
323+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
317324
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
318325
index = 1
319326
while (1):
@@ -337,6 +344,68 @@ def predict_video(self, video_file, camera_id):
337344
break
338345
writer.release()
339346

347+
@staticmethod
348+
def format_coco_results(image_list, results, save_file=None):
349+
coco_results = []
350+
image_id = 0
351+
352+
for result in results:
353+
start_idx = 0
354+
for box_num in result['boxes_num']:
355+
idx_slice = slice(start_idx, start_idx + box_num)
356+
start_idx += box_num
357+
358+
image_file = image_list[image_id]
359+
image_id += 1
360+
361+
if 'boxes' in result:
362+
boxes = result['boxes'][idx_slice, :]
363+
per_result = [
364+
{
365+
'image_file': image_file,
366+
'bbox':
367+
[box[2], box[3], box[4] - box[2],
368+
box[5] - box[3]], # xyxy -> xywh
369+
'score': box[1],
370+
'category_id': int(box[0]),
371+
} for k, box in enumerate(boxes.tolist())
372+
]
373+
374+
elif 'segm' in result:
375+
import pycocotools.mask as mask_util
376+
377+
scores = result['score'][idx_slice].tolist()
378+
category_ids = result['label'][idx_slice].tolist()
379+
segms = result['segm'][idx_slice, :]
380+
rles = [
381+
mask_util.encode(
382+
np.array(
383+
mask[:, :, np.newaxis],
384+
dtype=np.uint8,
385+
order='F'))[0] for mask in segms
386+
]
387+
for rle in rles:
388+
rle['counts'] = rle['counts'].decode('utf-8')
389+
390+
per_result = [{
391+
'image_file': image_file,
392+
'segmentation': rle,
393+
'score': scores[k],
394+
'category_id': category_ids[k],
395+
} for k, rle in enumerate(rles)]
396+
397+
else:
398+
raise RuntimeError('')
399+
400+
# per_result = [item for item in per_result if item['score'] > threshold]
401+
coco_results.extend(per_result)
402+
403+
if save_file:
404+
with open(os.path.join(save_file), 'w') as f:
405+
json.dump(coco_results, f)
406+
407+
return coco_results
408+
340409

341410
class DetectorSOLOv2(Detector):
342411
"""
@@ -807,7 +876,10 @@ def main():
807876
if FLAGS.image_dir is None and FLAGS.image_file is not None:
808877
assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
809878
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
810-
detector.predict_image(img_list, FLAGS.run_benchmark, repeats=100)
879+
save_file = os.path.join(FLAGS.output_dir,
880+
'results.json') if FLAGS.save_results else None
881+
detector.predict_image(
882+
img_list, FLAGS.run_benchmark, repeats=100, save_file=save_file)
811883
if not FLAGS.run_benchmark:
812884
detector.det_times.info(average=True)
813885
else:

deploy/python/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,12 @@ def argsparser():
156156
type=ast.literal_eval,
157157
default=False,
158158
help="Whether do random padding for action recognition.")
159+
parser.add_argument(
160+
"--save_results",
161+
type=bool,
162+
default=False,
163+
help="Whether save detection result to file using coco format")
164+
159165
return parser
160166

161167

0 commit comments

Comments
 (0)