diff --git a/configs/_base_/datasets/coco_openpose.py b/configs/_base_/datasets/coco_openpose.py new file mode 100644 index 0000000000..9aedd9f0e4 --- /dev/null +++ b/configs/_base_/datasets/coco_openpose.py @@ -0,0 +1,157 @@ +dataset_info = dict( + dataset_name='coco_openpose', + paper_info=dict( + author='Zhe, Cao and Tomas, Simon and ' + 'Shih-En, Wei and Yaser, Sheikh', + title='OpenPose: Realtime Multi-Person 2D Pose ' + 'Estimation using Part Affinity Fields', + container='IEEE Transactions on Pattern Analysis ' + 'and Machine Intelligence', + year='2019', + homepage='https://github.com/CMU-Perceptual-Computing-Lab/openpose/', + ), + keypoint_info={ + 0: + dict(name='nose', id=0, color=[255, 0, 85], type='upper', swap=''), + 1: + dict(name='neck', id=1, color=[255, 0, 0], type='upper', swap=''), + 2: + dict( + name='right_shoulder', + id=2, + color=[255, 85, 0], + type='upper', + swap='left_shoulder'), + 3: + dict( + name='right_elbow', + id=3, + color=[255, 170, 0], + type='upper', + swap='left_elbow'), + 4: + dict( + name='right_wrist', + id=4, + color=[255, 255, 0], + type='upper', + swap='left_wrist'), + 5: + dict( + name='left_shoulder', + id=5, + color=[170, 255, 0], + type='upper', + swap='right_shoulder'), + 6: + dict( + name='left_elbow', + id=6, + color=[85, 255, 0], + type='upper', + swap='right_elbow'), + 7: + dict( + name='left_wrist', + id=7, + color=[0, 255, 0], + type='upper', + swap='right_wrist'), + 8: + dict( + name='right_hip', + id=8, + color=[255, 0, 170], + type='lower', + swap='left_hip'), + 9: + dict( + name='right_knee', + id=9, + color=[255, 0, 255], + type='lower', + swap='left_knee'), + 10: + dict( + name='right_ankle', + id=10, + color=[170, 0, 255], + type='lower', + swap='left_ankle'), + 11: + dict( + name='left_hip', + id=11, + color=[85, 255, 0], + type='lower', + swap='right_hip'), + 12: + dict( + name='left_knee', + id=12, + color=[0, 0, 255], + type='lower', + swap='right_knee'), + 13: + dict( + name='left_ankle', + id=13, + color=[0, 85, 255], + type='lower', + swap='right_ankle'), + 14: + dict( + name='right_eye', + id=14, + color=[0, 255, 170], + type='upper', + swap='left_eye'), + 15: + dict( + name='left_eye', + id=15, + color=[0, 255, 255], + type='upper', + swap='right_eye'), + 16: + dict( + name='right_ear', + id=16, + color=[0, 170, 255], + type='upper', + swap='left_ear'), + 17: + dict( + name='left_ear', + id=17, + color=[0, 170, 255], + type='upper', + swap='right_ear'), + }, + skeleton_info={ + 0: dict(link=('neck', 'right_shoulder'), id=0, color=[255, 0, 85]), + 1: dict(link=('neck', 'left_shoulder'), id=1, color=[255, 0, 0]), + 2: + dict(link=('right_shoulder', 'right_elbow'), id=2, color=[255, 85, 0]), + 3: + dict(link=('right_elbow', 'right_wrist'), id=3, color=[255, 170, 0]), + 4: + dict(link=('left_shoulder', 'left_elbow'), id=4, color=[255, 255, 0]), + 5: dict(link=('left_elbow', 'left_wrist'), id=5, color=[170, 255, 0]), + 6: dict(link=('neck', 'right_hip'), id=6, color=[85, 255, 0]), + 7: dict(link=('right_hip', 'right_knee'), id=7, color=[0, 255, 0]), + 8: dict(link=('right_knee', 'right_ankle'), id=8, color=[0, 255, 85]), + 9: dict(link=('neck', 'left_hip'), id=9, color=[0, 255, 170]), + 10: dict(link=('left_hip', 'left_knee'), id=10, color=[0, 255, 225]), + 11: dict(link=('left_knee', 'left_ankle'), id=11, color=[0, 170, 255]), + 12: dict(link=('neck', 'nose'), id=12, color=[0, 85, 255]), + 13: dict(link=('nose', 'right_eye'), id=13, color=[0, 0, 255]), + 14: dict(link=('right_eye', 'right_ear'), id=14, color=[255, 0, 170]), + 15: dict(link=('nose', 'left_eye'), id=15, color=[170, 0, 255]), + 16: dict(link=('left_eye', 'left_ear'), id=16, color=[255, 0, 255]), + }, + joint_weights=[1.] * 18, + sigmas=[ + 0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062, + 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089, 0.082 + ]) diff --git a/demo/image_demo.py b/demo/image_demo.py index 0aa4a9e057..bfbc808b1e 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -25,6 +25,29 @@ def parse_args(): action='store_true', default=False, help='Whether to show the index of keypoints') + parser.add_argument( + '--skeleton-style', + default='mmpose', + type=str, + choices=['mmpose', 'openpose'], + help='Skeleton style selection') + parser.add_argument( + '--kpt-thr', + type=float, + default=0.3, + help='Visualizing keypoint thresholds') + parser.add_argument( + '--radius', + type=int, + default=3, + help='Keypoint radius for visualization') + parser.add_argument( + '--thickness', + type=int, + default=1, + help='Link thickness for visualization') + parser.add_argument( + '--alpha', type=float, default=0.8, help='The transparency of bboxes') parser.add_argument( '--show', action='store_true', @@ -50,8 +73,13 @@ def main(): cfg_options=cfg_options) # init visualizer + model.cfg.visualizer.radius = args.radius + model.cfg.visualizer.alpha = args.alpha + model.cfg.visualizer.line_width = args.thickness + visualizer = VISUALIZERS.build(model.cfg.visualizer) - visualizer.set_dataset_meta(model.dataset_meta) + visualizer.set_dataset_meta( + model.dataset_meta, skeleton_style=args.skeleton_style) # inference a single image batch_results = inference_topdown(model, args.img) @@ -65,8 +93,10 @@ def main(): data_sample=results, draw_gt=False, draw_bbox=True, + kpt_thr=args.kpt_thr, draw_heatmap=args.draw_heatmap, show_kpt_idx=args.show_kpt_idx, + skeleton_style=args.skeleton_style, show=args.show, out_file=args.out_file) diff --git a/demo/topdown_demo_with_mmdet.py b/demo/topdown_demo_with_mmdet.py index 418f3695b9..f0938000f1 100644 --- a/demo/topdown_demo_with_mmdet.py +++ b/demo/topdown_demo_with_mmdet.py @@ -55,10 +55,11 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer, draw_heatmap=args.draw_heatmap, draw_bbox=args.draw_bbox, show_kpt_idx=args.show_kpt_idx, + skeleton_style=args.skeleton_style, show=args.show, wait_time=show_interval, out_file=out_file, - kpt_score_thr=args.kpt_thr) + kpt_thr=args.kpt_thr) # if there is no instance detected, return None return data_samples.get('pred_instances', None) @@ -110,7 +111,10 @@ def main(): default=0.3, help='IoU threshold for bounding box NMS') parser.add_argument( - '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold') + '--kpt-thr', + type=float, + default=0.3, + help='Visualizing keypoint thresholds') parser.add_argument( '--draw-heatmap', action='store_true', @@ -121,6 +125,12 @@ def main(): action='store_true', default=False, help='Whether to show the index of keypoints') + parser.add_argument( + '--skeleton-style', + default='mmpose', + type=str, + choices=['mmpose', 'openpose'], + help='Skeleton style selection') parser.add_argument( '--radius', type=int, @@ -131,6 +141,8 @@ def main(): type=int, default=1, help='Link thickness for visualization') + parser.add_argument( + '--alpha', type=float, default=0.8, help='The transparency of bboxes') parser.add_argument( '--draw-bbox', action='store_true', help='Draw bboxes of instances') @@ -164,11 +176,14 @@ def main(): # init visualizer pose_estimator.cfg.visualizer.radius = args.radius + pose_estimator.cfg.visualizer.alpha = args.alpha pose_estimator.cfg.visualizer.line_width = args.thickness + visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer) # the dataset_meta is loaded from the checkpoint and # then pass to the model in init_pose_estimator - visualizer.set_dataset_meta(pose_estimator.dataset_meta) + visualizer.set_dataset_meta( + pose_estimator.dataset_meta, skeleton_style=args.skeleton_style) input_type = mimetypes.guess_type(args.input)[0].split('/')[0] if input_type == 'image': diff --git a/mmpose/apis/inferencers/base_mmpose_inferencer.py b/mmpose/apis/inferencers/base_mmpose_inferencer.py index d99dcc1b68..2b54c7907b 100644 --- a/mmpose/apis/inferencers/base_mmpose_inferencer.py +++ b/mmpose/apis/inferencers/base_mmpose_inferencer.py @@ -314,7 +314,7 @@ def visualize(self, show=show, wait_time=wait_time, out_file=out_file, - kpt_score_thr=kpt_thr) + kpt_thr=kpt_thr) results.append(visualization) if show and not hasattr(self, '_window_close_cid'): diff --git a/mmpose/engine/hooks/visualization_hook.py b/mmpose/engine/hooks/visualization_hook.py index 0458bb9385..24b845f282 100644 --- a/mmpose/engine/hooks/visualization_hook.py +++ b/mmpose/engine/hooks/visualization_hook.py @@ -50,7 +50,7 @@ def __init__( self, enable: bool = False, interval: int = 50, - score_thr: float = 0.3, + kpt_thr: float = 0.3, show: bool = False, wait_time: float = 0., out_dir: Optional[str] = None, @@ -58,7 +58,7 @@ def __init__( ): self._visualizer: Visualizer = Visualizer.get_current_instance() self.interval = interval - self.score_thr = score_thr + self.kpt_thr = kpt_thr self.show = show if self.show: # No need to think about vis backends. @@ -112,7 +112,7 @@ def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, draw_heatmap=True, show=self.show, wait_time=self.wait_time, - kpt_score_thr=self.score_thr, + kpt_thr=self.kpt_thr, step=total_curr_iter) def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, @@ -163,6 +163,6 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, draw_bbox=True, draw_heatmap=True, wait_time=self.wait_time, - kpt_score_thr=self.score_thr, + kpt_thr=self.kpt_thr, out_file=out_file, step=self._test_index) diff --git a/mmpose/visualization/local_visualizer.py b/mmpose/visualization/local_visualizer.py index 1743ef7997..0e6c3f1bcb 100644 --- a/mmpose/visualization/local_visualizer.py +++ b/mmpose/visualization/local_visualizer.py @@ -10,6 +10,7 @@ from mmengine.structures import InstanceData, PixelData from mmengine.visualization import Visualizer +from mmpose.datasets.datasets.utils import parse_pose_metainfo from mmpose.registry import VISUALIZERS from mmpose.structures import PoseDataSample from .simcc_vis import SimCCVisualizer @@ -130,13 +131,20 @@ def __init__(self, # it will override the default value. self.dataset_meta = {} - def set_dataset_meta(self, dataset_meta: Dict): + def set_dataset_meta(self, + dataset_meta: Dict, + skeleton_style: str = 'mmpose'): """Assign dataset_meta to the visualizer. The default visualization settings will be overridden. Args: dataset_meta (dict): meta information of dataset. """ + if dataset_meta.get( + 'dataset_name') == 'coco' and skeleton_style == 'openpose': + dataset_meta = parse_pose_metainfo( + dict(from_file='configs/_base_/datasets/coco_openpose.py')) + if isinstance(dataset_meta, dict): self.dataset_meta = dataset_meta.copy() self.bbox_color = dataset_meta.get('bbox_color', self.bbox_color) @@ -211,16 +219,21 @@ def _draw_instances_bbox(self, image: np.ndarray, def _draw_instances_kpts(self, image: np.ndarray, instances: InstanceData, - kpt_score_thr: float = 0.3, - show_kpt_idx: bool = False): + kpt_thr: float = 0.3, + show_kpt_idx: bool = False, + skeleton_style: str = 'mmpose'): """Draw keypoints and skeletons (optional) of GT or prediction. Args: image (np.ndarray): The image to draw. instances (:obj:`InstanceData`): Data structure for instance-level annotations or predictions. - kpt_score_thr (float, optional): Minimum score of keypoints + kpt_thr (float, optional): Minimum threshold of keypoints to be shown. Default: 0.3. + show_kpt_idx (bool): Whether to show the index of keypoints. + Defaults to ``False`` + skeleton_style (str): Skeleton style selection. Defaults to + ``'mmpose'`` Returns: np.ndarray: the drawn image which channel is RGB. @@ -242,107 +255,130 @@ def _draw_instances_kpts(self, keypoints_visible = instances.keypoints_visible else: keypoints_visible = [np.ones(len(kpts)) for kpts in keypoints] - - for kpts, score, visible in zip(keypoints, scores, - keypoints_visible): - kpts = np.array(kpts, copy=False) - - if self.kpt_color is None or isinstance(self.kpt_color, str): - kpt_color = [self.kpt_color] * len(kpts) - elif len(self.kpt_color) == len(kpts): - kpt_color = self.kpt_color + keypoints_visible = np.array(keypoints_visible) + + kpt_info = np.concatenate((keypoints, np.array(scores).reshape( + -1, len(scores[0]), 1), np.array(keypoints_visible).reshape( + -1, len(keypoints_visible[0]), 1)), + axis=-1) + + if skeleton_style == 'openpose': + # compute neck joint + neck = np.mean(kpt_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + if kpt_info[:, 5, 2] < kpt_thr or kpt_info[:, 6, 2] < kpt_thr: + neck[:, 2] = 0 + # neck visible when visualizing gt + if kpt_info[:, 5, 3] < kpt_thr or kpt_info[:, 6, 3] < kpt_thr: + neck[:, 3] = 0 + new_kpt_info = np.insert(kpt_info, 17, neck, axis=1) + + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_kpt_info[:, openpose_idx] = new_kpt_info[:, mmpose_idx] + kpt_info = new_kpt_info + + kpts, score, visible = kpt_info[0, :, :2], kpt_info[ + 0, :, 2], kpt_info[0, :, 3] + if self.kpt_color is None or isinstance(self.kpt_color, str): + kpt_color = [self.kpt_color] * len(kpts) + elif len(self.kpt_color) == len(kpts): + kpt_color = self.kpt_color + else: + raise ValueError(f'the length of kpt_color ' + f'({len(self.kpt_color)}) does not matches ' + f'that of keypoints ({len(kpts)})') + + # draw each point on image + for kid, kpt in enumerate(kpts): + if score[kid] < kpt_thr or not visible[ + kid] or kpt_color[kid] is None: + # skip the point that should not be drawn + continue + + color = kpt_color[kid] + if not isinstance(color, str): + color = tuple(int(c) for c in color) + transparency = self.alpha + if self.show_keypoint_weight: + transparency *= max(0, min(1, score[kid])) + self.draw_circles( + kpt, + radius=np.array([self.radius]), + face_colors=color, + edge_colors=color, + alpha=transparency, + line_widths=self.radius) + if show_kpt_idx: + self.draw_texts( + str(kid), + kpt, + colors=color, + font_sizes=self.radius * 3, + vertical_alignments='bottom', + horizontal_alignments='center') + + # draw links + if self.skeleton is not None and self.link_color is not None: + if self.link_color is None or isinstance(self.link_color, str): + link_color = [self.link_color] * len(self.skeleton) + elif len(self.link_color) == len(self.skeleton): + link_color = self.link_color else: raise ValueError( - f'the length of kpt_color ' - f'({len(self.kpt_color)}) does not matches ' - f'that of keypoints ({len(kpts)})') - - # draw each point on image - for kid, kpt in enumerate(kpts): - if score[kid] < kpt_score_thr or not visible[ - kid] or kpt_color[kid] is None: - # skip the point that should not be drawn + f'the length of link_color ' + f'({len(self.link_color)}) does not matches ' + f'that of skeleton ({len(self.skeleton)})') + + for sk_id, sk in enumerate(self.skeleton): + pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) + pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) + if not (visible[sk[0]] and visible[sk[1]]): continue - color = kpt_color[kid] + if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 + or pos1[1] >= img_h or pos2[0] <= 0 + or pos2[0] >= img_w or pos2[1] <= 0 + or pos2[1] >= img_h or score[sk[0]] < kpt_thr + or score[sk[1]] < kpt_thr + or link_color[sk_id] is None): + # skip the link that should not be drawn + continue + X = np.array((pos1[0], pos2[0])) + Y = np.array((pos1[1], pos2[1])) + color = link_color[sk_id] if not isinstance(color, str): color = tuple(int(c) for c in color) transparency = self.alpha if self.show_keypoint_weight: - transparency *= max(0, min(1, score[kid])) - self.draw_circles( - kpt, - radius=np.array([self.radius]), - face_colors=color, - edge_colors=color, - alpha=transparency, - line_widths=self.radius) - if show_kpt_idx: - self.draw_texts( - str(kid), - kpt, - colors=color, - font_sizes=self.radius * 3, - vertical_alignments='bottom', - horizontal_alignments='center') - - # draw links - if self.skeleton is not None and self.link_color is not None: - if self.link_color is None or isinstance( - self.link_color, str): - link_color = [self.link_color] * len(self.skeleton) - elif len(self.link_color) == len(self.skeleton): - link_color = self.link_color + transparency *= max( + 0, min(1, 0.5 * (score[sk[0]] + score[sk[1]]))) + + if skeleton_style == 'openpose': + mX = np.mean(X) + mY = np.mean(Y) + length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 + angle = math.degrees( + math.atan2(Y[0] - Y[1], X[0] - X[1])) + stickwidth = 2 + polygons = cv2.ellipse2Poly( + (int(mX), int(mY)), + (int(length / 2), int(stickwidth)), int(angle), 0, + 360, 1) + + self.draw_polygons( + polygons, + edge_colors=color, + face_colors=color, + alpha=transparency) + else: - raise ValueError( - f'the length of link_color ' - f'({len(self.link_color)}) does not matches ' - f'that of skeleton ({len(self.skeleton)})') - - for sk_id, sk in enumerate(self.skeleton): - pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) - pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) - if not (visible[sk[0]] and visible[sk[1]]): - continue - - if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 - or pos1[1] >= img_h or pos2[0] <= 0 - or pos2[0] >= img_w or pos2[1] <= 0 - or pos2[1] >= img_h - or score[sk[0]] < kpt_score_thr - or score[sk[1]] < kpt_score_thr - or link_color[sk_id] is None): - # skip the link that should not be drawn - continue - X = np.array((pos1[0], pos2[0])) - Y = np.array((pos1[1], pos2[1])) - color = link_color[sk_id] - if not isinstance(color, str): - color = tuple(int(c) for c in color) - if self.show_keypoint_weight: - - mX = np.mean(X) - mY = np.mean(Y) - length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 - angle = math.degrees( - math.atan2(Y[0] - Y[1], X[0] - X[1])) - stickwidth = 2 - polygons = cv2.ellipse2Poly( - (int(mX), int(mY)), - (int(length / 2), int(stickwidth)), int(angle), - 0, 360, 1) - transparency = self.alpha - transparency *= max( - 0, min(1, 0.5 * (score[sk[0]] + score[sk[1]]))) - self.draw_polygons( - polygons, - edge_colors=color, - face_colors=color, - alpha=transparency) - - else: - self.draw_lines( - X, Y, color, line_widths=self.line_width) + self.draw_lines( + X, Y, color, line_widths=self.line_width) return self.get_image() @@ -410,10 +446,11 @@ def add_datasample(self, draw_heatmap: bool = False, draw_bbox: bool = False, show_kpt_idx: bool = False, + skeleton_style: str = 'mmpose', show: bool = False, wait_time: float = 0, out_file: Optional[str] = None, - kpt_score_thr: float = 0.3, + kpt_thr: float = 0.3, step: int = 0) -> None: """Draw datasample and save to all backends. @@ -439,12 +476,16 @@ def add_datasample(self, ``False`` draw_heatmap (bool): Whether to draw heatmaps. Defaults to ``False`` + show_kpt_idx (bool): Whether to show the index of keypoints. + Defaults to ``False`` + skeleton_style (str): Skeleton style selection. Defaults to + ``'mmpose'`` show (bool): Whether to display the drawn image. Default to ``False`` wait_time (float): The interval of show (s). Defaults to 0 out_file (str): Path to output file. Defaults to ``None`` - pred_score_thr (float): The threshold to visualize the bboxes - and masks. Defaults to 0.3 + kpt_thr (float, optional): Minimum threshold of keypoints + to be shown. Default: 0.3. step (int): Global step value to record. Defaults to 0 """ @@ -458,8 +499,8 @@ def add_datasample(self, # draw bboxes & keypoints if 'gt_instances' in data_sample: gt_img_data = self._draw_instances_kpts( - gt_img_data, data_sample.gt_instances, kpt_score_thr, - show_kpt_idx) + gt_img_data, data_sample.gt_instances, kpt_thr, + show_kpt_idx, skeleton_style) if draw_bbox: gt_img_data = self._draw_instances_bbox( gt_img_data, data_sample.gt_instances) @@ -479,8 +520,8 @@ def add_datasample(self, # draw bboxes & keypoints if 'pred_instances' in data_sample: pred_img_data = self._draw_instances_kpts( - pred_img_data, data_sample.pred_instances, kpt_score_thr, - show_kpt_idx) + pred_img_data, data_sample.pred_instances, kpt_thr, + show_kpt_idx, skeleton_style) if draw_bbox: pred_img_data = self._draw_instances_bbox( pred_img_data, data_sample.pred_instances) diff --git a/projects/mmpose4aigc/openpose_visualization.py b/projects/mmpose4aigc/openpose_visualization.py index 3410f2d1f3..b7fde6eae0 100644 --- a/projects/mmpose4aigc/openpose_visualization.py +++ b/projects/mmpose4aigc/openpose_visualization.py @@ -64,7 +64,8 @@ def mmpose_to_openpose_visualization(args, img_path, detector, pose_estimator): # compute neck joint neck = (keypoints[:, 5] + keypoints[:, 6]) / 2 - neck[:, 2] = keypoints[:, 5, 2] * keypoints[:, 6, 2] + if keypoints[:, 5, 2] < args.kpt_thr or keypoints[:, 6, 2] < args.kpt_thr: + neck[:, 2] = 0 # 17 keypoints to 18 keypoints new_keypoints = np.insert(keypoints[:, ], 17, neck, axis=1) @@ -83,27 +84,26 @@ def mmpose_to_openpose_visualization(args, img_path, detector, pose_estimator): num_instance = new_keypoints.shape[0] # draw keypoints - cur_black_img = black_img.copy() for i, j in product(range(num_instance), range(num_openpose_kpt)): x, y, conf = new_keypoints[i][j] - if conf <= 1e-5: - continue - cv2.circle(cur_black_img, (int(x), int(y)), 4, colors[j], thickness=-1) - black_img = cv2.addWeighted(black_img, 0.3, cur_black_img, 0.7, 0) + if conf > args.kpt_thr: + cv2.circle(black_img, (int(x), int(y)), 4, colors[j], thickness=-1) # draw links cur_black_img = black_img.copy() for i, link_idx in product(range(num_instance), range(num_link)): - Y = new_keypoints[i][np.array(limb_seq[link_idx]) - 1, 0] - X = new_keypoints[i][np.array(limb_seq[link_idx]) - 1, 1] - mX = np.mean(X) - mY = np.mean(Y) - length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5 - angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) - polygon = cv2.ellipse2Poly( - (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, - 360, 1) - cv2.fillConvexPoly(cur_black_img, polygon, colors[link_idx]) + conf = new_keypoints[i][np.array(limb_seq[link_idx]) - 1, 2] + if np.sum(conf > args.kpt_thr) == 2: + Y = new_keypoints[i][np.array(limb_seq[link_idx]) - 1, 0] + X = new_keypoints[i][np.array(limb_seq[link_idx]) - 1, 1] + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly( + (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), + 0, 360, 1) + cv2.fillConvexPoly(cur_black_img, polygon, colors[link_idx]) black_img = cv2.addWeighted(black_img, 0.4, cur_black_img, 0.6, 0) # save image @@ -133,7 +133,7 @@ def main(): parser.add_argument( '--bbox-thr', type=float, - default=0.3, + default=0.4, help='Bounding box score threshold') parser.add_argument( '--nms-thr', @@ -141,7 +141,7 @@ def main(): default=0.3, help='IoU threshold for bounding box NMS') parser.add_argument( - '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold') + '--kpt-thr', type=float, default=0.4, help='Keypoint score threshold') assert has_mmdet, 'Please install mmdet to run the demo.'