11# Copyright (c) OpenMMLab. All rights reserved.
2- from collections import defaultdict
2+ import math
3+ from itertools import groupby
34from typing import Dict , List , Optional , Tuple , Union
45
6+ import cv2
7+ import mmcv
58import numpy as np
6- from mmengine .structures import InstanceData
79
8- from mmpose .structures import PoseDataSample
9- from mmpose .visualization import PoseLocalVisualizer
1010from ...utils import FrameMessage
1111from ..base_visualizer_node import BaseVisualizerNode
1212from ..registry import NODES
1313
1414
def imshow_bboxes(img,
                  bboxes,
                  labels=None,
                  colors='green',
                  text_color='white',
                  thickness=1,
                  font_scale=0.5):
    """Draw bboxes with labels (optional) on an image.

    This is a wrapper of mmcv.imshow_bboxes that additionally renders a
    filled text tag at the top-left corner of every labelled bbox.

    Args:
        img (str or ndarray): The image to be displayed.
        bboxes (ndarray): ndarray of shape (k, 4), each row is a bbox in
            format [x1, y1, x2, y2]. Only the first four columns are used
            to place the label.
        labels (str or list[str], optional): Labels of each bbox. A single
            non-list value is applied to every bbox; ``None`` entries in a
            list are skipped.
        colors (str or tuple or :obj:`Color` or list): Bbox colors. A
            non-list value is applied to every bbox; a list must contain
            one color per bbox.
        text_color (str or tuple or :obj:`Color`): Color of texts.
        thickness (int): Thickness of lines.
        font_scale (float): Font scales of texts.

    Returns:
        ndarray: The image with bboxes drawn on it.
    """

    # adapt to mmcv.imshow_bboxes input format: a list with one (1, C)
    # array per bbox
    num_bboxes = bboxes.shape[0]
    bboxes = np.split(bboxes, num_bboxes, axis=0) if num_bboxes > 0 else []
    if not isinstance(colors, list):
        colors = [colors] * num_bboxes
    colors = [mmcv.color_val(c) for c in colors]
    assert len(bboxes) == len(colors)

    img = mmcv.imshow_bboxes(
        img,
        bboxes,
        colors,
        top_k=-1,
        thickness=thickness,
        show=False,
        out_file=None)

    if labels is None:
        return img

    if not isinstance(labels, list):
        labels = [labels] * num_bboxes
    assert len(labels) == len(bboxes)

    # loop-invariant drawing parameters, resolved once
    font_face = cv2.FONT_HERSHEY_DUPLEX
    text_bgr = mmcv.color_val(text_color)

    for bbox, label, color in zip(bboxes, labels, colors):
        if label is None:
            continue
        bbox_int = bbox[0, :4].astype(np.int32)
        # roughly estimate the proper font size
        (text_w, text_h), text_baseline = cv2.getTextSize(
            label, font_face, font_scale, thickness)
        # Use plain Python ints so cv2 always receives homogeneous point
        # tuples (max(0, ...) may otherwise mix int and np.int32).
        text_x1 = int(bbox_int[0])
        # keep the tag inside the image when the bbox touches the top edge
        text_y1 = max(0, int(bbox_int[1]) - text_h - text_baseline)
        text_x2 = text_x1 + text_w
        text_y2 = text_y1 + text_h + text_baseline
        cv2.rectangle(img, (text_x1, text_y1), (text_x2, text_y2), color,
                      cv2.FILLED)
        cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
                    font_face, font_scale, text_bgr, thickness)

    return img
80+
81+
def imshow_keypoints(img,
                     pose_result,
                     skeleton=None,
                     kpt_score_thr=0.3,
                     pose_kpt_color=None,
                     pose_link_color=None,
                     radius=4,
                     thickness=1,
                     show_keypoint_weight=False):
    """Draw keypoints and links on an image.

    Args:
        img (str or ndarray): The image to draw poses on. If an image array
            is given, it will be modified in-place.
        pose_result (list[kpts]): The poses to draw. Each element kpts is
            a set of K keypoints as an Kx3 numpy.ndarray, where each
            keypoint is represented as x, y, score.
        skeleton (list, optional): Links to draw, each given as a pair of
            keypoint indices. If None, the links will not be drawn.
        kpt_score_thr (float, optional): Minimum score of keypoints
            to be shown. Default: 0.3.
        pose_kpt_color (np.array[Nx3]): Color of N keypoints. If None,
            the keypoint will not be drawn.
        pose_link_color (np.array[Mx3]): Color of M links. If None, the
            links will not be drawn.
        radius (int): Radius of the keypoint circles.
        thickness (int): Thickness of lines.
        show_keypoint_weight (bool): If True, each keypoint/link is blended
            into the image using its score (clipped to [0, 1]) as opacity
            instead of being drawn fully opaque.

    Returns:
        ndarray: The image with keypoints and links drawn on it.
    """

    img = mmcv.imread(img)
    img_h, img_w, _ = img.shape

    for kpts in pose_result:

        # np.asarray converts without copying when possible. The former
        # np.array(kpts, copy=False) raises ValueError on NumPy >= 2.0
        # whenever a copy is unavoidable (e.g. list input).
        kpts = np.asarray(kpts)

        # draw each point on image
        if pose_kpt_color is not None:
            assert len(pose_kpt_color) == len(kpts)

            for kid, kpt in enumerate(kpts):
                x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]

                if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
                    # skip the point that should not be drawn
                    continue

                color = tuple(int(c) for c in pose_kpt_color[kid])
                if show_keypoint_weight:
                    # draw on a copy, then alpha-blend it back so that
                    # low-score keypoints appear more transparent
                    img_copy = img.copy()
                    cv2.circle(img_copy, (x_coord, y_coord), radius, color,
                               -1)
                    transparency = max(0, min(1, kpt_score))
                    cv2.addWeighted(
                        img_copy,
                        transparency,
                        img,
                        1 - transparency,
                        0,
                        dst=img)
                else:
                    cv2.circle(img, (x_coord, y_coord), radius, color, -1)

        # draw links
        if skeleton is not None and pose_link_color is not None:
            assert len(pose_link_color) == len(skeleton)

            for sk_id, sk in enumerate(skeleton):
                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))

                if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0
                        or pos1[1] >= img_h or pos2[0] <= 0
                        or pos2[0] >= img_w or pos2[1] <= 0
                        or pos2[1] >= img_h
                        or kpts[sk[0], 2] < kpt_score_thr
                        or kpts[sk[1], 2] < kpt_score_thr
                        or pose_link_color[sk_id] is None):
                    # skip the link that should not be drawn
                    continue

                color = tuple(int(c) for c in pose_link_color[sk_id])
                if show_keypoint_weight:
                    # render the link as a thin filled ellipse on a copy,
                    # then blend it using the mean score of both endpoints
                    img_copy = img.copy()
                    xs = (pos1[0], pos2[0])
                    ys = (pos1[1], pos2[1])
                    mean_x = np.mean(xs)
                    mean_y = np.mean(ys)
                    length = ((ys[0] - ys[1])**2 + (xs[0] - xs[1])**2)**0.5
                    angle = math.degrees(
                        math.atan2(ys[0] - ys[1], xs[0] - xs[1]))
                    stickwidth = 2
                    polygon = cv2.ellipse2Poly(
                        (int(mean_x), int(mean_y)),
                        (int(length / 2), int(stickwidth)), int(angle), 0,
                        360, 1)
                    cv2.fillConvexPoly(img_copy, polygon, color)
                    transparency = max(
                        0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2])))
                    cv2.addWeighted(
                        img_copy,
                        transparency,
                        img,
                        1 - transparency,
                        0,
                        dst=img)
                else:
                    cv2.line(img, pos1, pos2, color, thickness=thickness)

    return img
186+
187+
15188@NODES .register_module ()
16189class ObjectVisualizerNode (BaseVisualizerNode ):
17190 """Visualize the bounding box and keypoints of objects.
@@ -91,12 +264,11 @@ def __init__(self,
91264 self .show_bbox = show_bbox
92265 self .show_keypoint = show_keypoint
93266 self .must_have_keypoint = must_have_keypoint
267+ self .radius = radius
268+ self .thickness = thickness
94269
95- self .visualizer = PoseLocalVisualizer (
96- name = 'webcam' , radius = radius , line_width = thickness )
97-
98- def draw (self , input_msg : FrameMessage ) -> np .ndarray :
99- canvas = input_msg .get_image ()
270+ def _draw_bbox (self , canvas : np .ndarray , input_msg : FrameMessage ):
271+ """Draw object bboxes."""
100272
101273 if self .must_have_keypoint :
102274 objects = input_msg .get_objects (
@@ -107,49 +279,63 @@ def draw(self, input_msg: FrameMessage) -> np.ndarray:
107279 if not objects :
108280 return canvas
109281
110- objects_by_label = defaultdict (list )
111- for object in objects :
112- objects_by_label [object ['label' ]].append (object )
113-
114- # draw objects of each category individually
115- for label , objects in objects_by_label .items ():
116- dataset_meta = objects [0 ]['dataset_meta' ]
117- dataset_meta ['bbox_color' ] = self .default_bbox_color .get (
118- label , self .bbox_color )
119- self .visualizer .set_dataset_meta (dataset_meta )
120-
121- # assign bboxes, keypoints and other predictions to data_sample
122- instances = InstanceData ()
123- instances .bboxes = np .stack ([object ['bbox' ] for object in objects ])
124- instances .labels = np .array (
125- [object ['class_id' ] for object in objects ])
126- if self .show_keypoint :
127- keypoints = [
128- object ['keypoints' ] for object in objects
129- if 'keypoints' in object
130- ]
131- if len (keypoints ):
132- instances .keypoints = np .stack (keypoints )
133- keypoint_scores = [
134- object ['keypoint_scores' ] for object in objects
135- if 'keypoint_scores' in object
136- ]
137- if len (keypoint_scores ):
138- instances .keypoint_scores = np .stack (keypoint_scores )
139- data_sample = PoseDataSample ()
140- data_sample .pred_instances = instances
141-
142- self .visualizer .add_datasample (
143- 'result' ,
282+ bboxes = [obj ['bbox' ] for obj in objects ]
283+ labels = [obj .get ('label' , None ) for obj in objects ]
284+ default_color = (0 , 255 , 0 )
285+
286+ # Get bbox colors
287+ if isinstance (self .bbox_color , dict ):
288+ colors = [
289+ self .bbox_color .get (label , default_color ) for label in labels
290+ ]
291+ else :
292+ colors = self .bbox_color
293+
294+ imshow_bboxes (
295+ canvas ,
296+ np .vstack (bboxes ),
297+ labels = labels ,
298+ colors = colors ,
299+ text_color = 'white' ,
300+ font_scale = 0.5 )
301+
302+ return canvas
303+
def _draw_keypoint(self, canvas: np.ndarray,
                   input_msg: FrameMessage) -> np.ndarray:
    """Draw object keypoints.

    Objects carrying a ``'pose_model_cfg'`` entry are grouped by that
    config, and each group is drawn with the skeleton and colors taken from
    its own ``'dataset_meta'``.

    Args:
        canvas (np.ndarray): The image to draw on.
        input_msg (FrameMessage): The message holding the detected objects.

    Returns:
        np.ndarray: The canvas with keypoints and links drawn on it.
    """
    objects = input_msg.get_objects(lambda x: 'pose_model_cfg' in x)

    # return if there is no object with keypoints
    if not objects:
        return canvas

    # groupby only merges *consecutive* objects with an equal config; a
    # config that re-appears later simply forms another group, which is
    # harmless because every group is drawn independently.
    for _, group in groupby(objects, lambda x: x['pose_model_cfg']):
        # materialise the group: it is needed twice, and a groupby
        # sub-iterator can only be consumed once
        group = list(group)
        # use the metadata of this group's own model rather than that of
        # the first object in the whole frame
        dataset_info = group[0]['dataset_meta']
        # pack (x, y) and score into one K x 3 array per object
        keypoints = [
            np.concatenate(
                (obj['keypoints'], obj['keypoint_scores'][:, None]),
                axis=1) for obj in group
        ]
        canvas = imshow_keypoints(
            canvas,
            keypoints,
            skeleton=dataset_info['skeleton_links'],
            kpt_score_thr=self.kpt_thr,
            pose_kpt_color=dataset_info['keypoint_colors'],
            pose_link_color=dataset_info['skeleton_link_colors'],
            radius=self.radius,
            thickness=self.thickness)

    return canvas
331+
def draw(self, input_msg: FrameMessage) -> np.ndarray:
    """Render the enabled overlays onto the frame image.

    The image is fetched from ``input_msg``; bboxes are drawn first when
    ``self.show_bbox`` is set, then keypoints when ``self.show_keypoint``
    is set. The resulting image is returned.
    """
    frame = input_msg.get_image()
    frame = self._draw_bbox(frame, input_msg) if self.show_bbox else frame
    frame = (
        self._draw_keypoint(frame, input_msg)
        if self.show_keypoint else frame)
    return frame
0 commit comments