|
27 | 27 |
|
28 | 28 | from .map_utils import prune_zero_padding, DetectionMAP |
29 | 29 | from .coco_utils import get_infer_results, cocoapi_eval |
30 | | -from .widerface_utils import face_eval_run |
| 30 | +from .widerface_utils import (face_eval_run, image_eval, img_pr_info, |
| 31 | + dataset_pr_info, voc_ap) |
31 | 32 | from ppdet.data.source.category import get_categories |
32 | 33 | from ppdet.modeling.rbox_utils import poly2rbox_np |
33 | 34 |
|
@@ -337,22 +338,93 @@ def get_results(self): |
337 | 338 |
|
338 | 339 |
|
339 | 340 | class WiderFaceMetric(Metric): |
340 | | - def __init__(self, image_dir, anno_file, multi_scale=True): |
341 | | - self.image_dir = image_dir |
342 | | - self.anno_file = anno_file |
343 | | - self.multi_scale = multi_scale |
344 | | - self.clsid2catid, self.catid2name = get_categories('widerface') |
345 | | - |
346 | | - def update(self, model): |
347 | | - |
348 | | - face_eval_run( |
349 | | - model, |
350 | | - self.image_dir, |
351 | | - self.anno_file, |
352 | | - pred_dir='output/pred', |
353 | | - eval_mode='widerface', |
354 | | - multi_scale=self.multi_scale) |
| 341 | + def __init__(self, iou_thresh=0.5): |
| 342 | + self.iou_thresh = iou_thresh |
| 343 | + self.reset() |
355 | 344 |
|
| 345 | + def reset(self): |
| 346 | + self.pred_boxes_list = [] |
| 347 | + self.gt_boxes_list = [] |
| 348 | + self.aps = [] |
| 349 | + |
| 350 | + self.hard_ignore_list = [] |
| 351 | + self.medium_ignore_list = [] |
| 352 | + self.easy_ignore_list = [] |
| 353 | + |
| 354 | + def update(self, data, outs): |
| 355 | + batch_pred_bboxes = outs['bbox'] |
| 356 | + batch_pred_bboxes_num = outs['bbox_num'] |
| 357 | + assert len(batch_pred_bboxes_num) == len(data['gt_bbox']) |
| 358 | + batch_size = len(data['gt_bbox']) |
| 359 | + box_cnt = 0 |
| 360 | + for batch_id in range(batch_size): |
| 361 | + pred_bboxes_num = batch_pred_bboxes_num[batch_id] |
| 362 | + pred_bboxes = batch_pred_bboxes[box_cnt: box_cnt + |
| 363 | + pred_bboxes_num].numpy() |
| 364 | + box_cnt += pred_bboxes_num |
| 365 | + |
| 366 | + det_conf = pred_bboxes[:, 1] |
| 367 | + det_xmin = pred_bboxes[:, 2] |
| 368 | + det_ymin = pred_bboxes[:, 3] |
| 369 | + det_xmax = pred_bboxes[:, 4] |
| 370 | + det_ymax = pred_bboxes[:, 5] |
| 371 | + det = np.column_stack((det_xmin, det_ymin, det_xmax, |
| 372 | + det_ymax, det_conf)) |
| 373 | + self.pred_boxes_list.append(det) # xyxy conf |
| 374 | + self.gt_boxes_list.append(data['gt_ori_bbox'][batch_id].numpy()) # xywh |
| 375 | + self.hard_ignore_list.append( |
| 376 | + data['gt_hard_ignore'][batch_id].numpy()) |
| 377 | + self.medium_ignore_list.append( |
| 378 | + data['gt_medium_ignore'][batch_id].numpy()) |
| 379 | + self.easy_ignore_list.append( |
| 380 | + data['gt_easy_ignore'][batch_id].numpy()) |
| 381 | + |
| 382 | + def accumulate(self): |
| 383 | + total_num = len(self.gt_boxes_list) |
| 384 | + settings = ['easy', 'medium', 'hard'] |
| 385 | + setting_ingores = [self.easy_ignore_list, |
| 386 | + self.medium_ignore_list, |
| 387 | + self.hard_ignore_list] |
| 388 | + thresh_num = 1000 |
| 389 | + aps = [] |
| 390 | + for setting_id in range(3): |
| 391 | + count_face = 0 |
| 392 | + pr_curve = np.zeros((thresh_num, 2)).astype(np.float32) |
| 393 | + gt_ignore_list = setting_ingores[setting_id] |
| 394 | + for i in range(total_num): |
| 395 | + pred_boxes = self.pred_boxes_list[i] # xyxy conf |
| 396 | + gt_boxes = self.gt_boxes_list[i] # xywh |
| 397 | + ignore = gt_ignore_list[i] |
| 398 | + count_face += np.sum(ignore) |
| 399 | + |
| 400 | + if len(gt_boxes) == 0 or len(pred_boxes) == 0: |
| 401 | + continue |
| 402 | + pred_recall, proposal_list = image_eval(pred_boxes, gt_boxes, |
| 403 | + ignore, self.iou_thresh) |
| 404 | + _img_pr_info = img_pr_info(thresh_num, pred_boxes, |
| 405 | + proposal_list, pred_recall) |
| 406 | + pr_curve += _img_pr_info |
| 407 | + pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face) |
| 408 | + |
| 409 | + propose = pr_curve[:, 0] |
| 410 | + recall = pr_curve[:, 1] |
| 411 | + |
| 412 | + ap = voc_ap(recall, propose) |
| 413 | + aps.append(ap) |
| 414 | + self.aps = aps |
| 415 | + |
| 416 | + def log(self): |
| 417 | + logger.info("==================== Results ====================") |
| 418 | + logger.info("Easy Val AP: {}".format(self.aps[0])) |
| 419 | + logger.info("Medium Val AP: {}".format(self.aps[1])) |
| 420 | + logger.info("Hard Val AP: {}".format(self.aps[2])) |
| 421 | + logger.info("=================================================") |
| 422 | + |
| 423 | + def get_results(self): |
| 424 | + return { |
| 425 | + 'easy_ap': self.aps[0], |
| 426 | + 'medium_ap': self.aps[1], |
| 427 | + 'hard_ap': self.aps[2]} |
356 | 428 |
|
357 | 429 | class RBoxMetric(Metric): |
358 | 430 | def __init__(self, anno_file, **kwargs): |
|
0 commit comments