Skip to content

Commit 3af7fed

Browse files
authored
[Fix] Use opencv backend in Webcam API visualization (#2089)
1 parent cd71020 commit 3af7fed

File tree

5 files changed

+250
-59
lines changed

5 files changed

+250
-59
lines changed

demo/docs/webcam_demo.md renamed to demo/docs/webcam_api_demo.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@ Launch the demo from the mmpose root directory:
1212

1313
```shell
1414
# Run webcam demo with GPU
15-
python demo/webcam_demo.py
15+
python demo/webcam_api_demo.py
1616

1717
# Run webcam demo with CPU
18-
python demo/webcam_demo.py --cpu
18+
python demo/webcam_api_demo.py --cpu
1919
```
2020

2121
The command above will use the default config file `demo/webcam_cfg/pose_estimation.py`. You can also specify the config file in the command:
2222

2323
```shell
2424
# Use the config "pose_tracking.py" for higher inference speed
25-
python demo/webcam_demo.py --config demo/webcam_cfg/pose_estimation.py
25+
python demo/webcam_api_demo.py --config demo/webcam_cfg/pose_estimation.py
2626
```
2727

2828
### Hotkeys
@@ -36,7 +36,7 @@ python demo/webcam_demo.py --config demo/webcam_cfg/pose_estimation.py
3636
| m | Show the monitoring information. |
3737
| q | Exit. |
3838

39-
Note that the demo will automatically save the output video into a file `webcam_demo.mp4`.
39+
Note that the demo will automatically save the output video into a file `webcam_api_demo.mp4`.
4040

4141
### Usage and configurations
4242

@@ -103,5 +103,5 @@ Detailed configurations can be found in the config file.
103103
Run the following command for a quick test of video capturing and displaying.
104104

105105
```shell
106-
python demo/webcam_demo.py --config demo/webcam_cfg/test_camera.py
106+
python demo/webcam_api_demo.py --config demo/webcam_cfg/test_camera.py
107107
```

demo/webcam_demo.py renamed to demo/webcam_api_demo.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22

33
import logging
4+
import warnings
45
from argparse import ArgumentParser
56

67
from mmengine import Config, DictAction
@@ -51,6 +52,10 @@ def set_device(cfg: Config, device: str):
5152

5253

5354
def run():
55+
56+
warnings.warn('The Webcam API will be deprecated in future. ',
57+
DeprecationWarning)
58+
5459
args = parse_args()
5560
cfg = Config.fromfile(args.config)
5661
cfg.merge_from_dict(args.cfg_options)

demo/webcam_cfg/pose_estimation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@
129129
dict(
130130
type='RecorderNode',
131131
name='recorder',
132-
out_video_file='webcam_demo.mp4',
132+
out_video_file='webcam_api_demo.mp4',
133133
input_buffer='display',
134134
output_buffer='_display_'
135135
# `_display_` is an executor-reserved buffer

demo/webcam_cfg/test_camera.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
dict(
1717
type='RecorderNode',
1818
name='recorder',
19-
out_video_file='webcam_output.mp4',
19+
out_video_file='webcam_api_output.mp4',
2020
input_buffer='display',
2121
output_buffer='_display_')
2222
])

mmpose/apis/webcam/nodes/visualizer_nodes/object_visualizer_node.py

Lines changed: 238 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,190 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from collections import defaultdict
2+
import math
3+
from itertools import groupby
34
from typing import Dict, List, Optional, Tuple, Union
45

6+
import cv2
7+
import mmcv
58
import numpy as np
6-
from mmengine.structures import InstanceData
79

8-
from mmpose.structures import PoseDataSample
9-
from mmpose.visualization import PoseLocalVisualizer
1010
from ...utils import FrameMessage
1111
from ..base_visualizer_node import BaseVisualizerNode
1212
from ..registry import NODES
1313

1414

15+
def imshow_bboxes(img,
                  bboxes,
                  labels=None,
                  colors='green',
                  text_color='white',
                  thickness=1,
                  font_scale=0.5):
    """Draw bounding boxes, optionally tagged with text labels, on an image.

    The boxes themselves are rendered by ``mmcv.imshow_bboxes``; this wrapper
    adapts the input format and adds a filled label tag above each box.

    Args:
        img (str or ndarray): The image to draw on.
        bboxes (ndarray): Array of shape (k, 4); each row is one box given
            as [x1, y1, x2, y2].
        labels (str or list[str], optional): Label text for the boxes. A
            single value is applied to every box; ``None`` entries in a list
            are skipped.
        colors (str or tuple or :obj:`Color` or list): Box color. A single
            value is applied to every box; a list must contain one color
            per box.
        text_color (str or tuple or :obj:`Color`): Color of the label text.
        thickness (int): Thickness of box lines and label text.
        font_scale (float): Font scale of the label text.

    Returns:
        ndarray: The image with bboxes drawn on it.
    """
    num_boxes = bboxes.shape[0]

    # mmcv.imshow_bboxes takes a list of arrays, so hand it one single-row
    # array per box
    box_list = np.split(bboxes, num_boxes, axis=0) if num_boxes > 0 else []

    if isinstance(colors, list):
        palette = colors
    else:
        palette = [colors] * len(box_list)
    palette = [mmcv.color_val(entry) for entry in palette]
    assert len(box_list) == len(palette)

    img = mmcv.imshow_bboxes(
        img,
        box_list,
        palette,
        top_k=-1,
        thickness=thickness,
        show=False,
        out_file=None)

    # nothing more to do when no label text was requested
    if labels is None:
        return img

    if not isinstance(labels, list):
        labels = [labels] * len(box_list)
    assert len(labels) == len(box_list)

    font = cv2.FONT_HERSHEY_DUPLEX
    for box, text, box_color in zip(box_list, labels, palette):
        if text is None:
            continue

        corner = box[0, :4].astype(np.int32)
        # measure the rendered text so the background tag fits it
        (text_w, text_h), baseline = cv2.getTextSize(text, font, font_scale,
                                                     thickness)
        left = corner[0]
        # clamp so the tag never starts above the top edge of the image
        top = max(0, corner[1] - text_h - baseline)
        right = corner[0] + text_w
        bottom = top + text_h + baseline

        cv2.rectangle(img, (left, top), (right, bottom), box_color,
                      cv2.FILLED)
        cv2.putText(img, text, (left, bottom - baseline), font, font_scale,
                    mmcv.color_val(text_color), thickness)

    return img
80+
81+
82+
def imshow_keypoints(img,
                     pose_result,
                     skeleton=None,
                     kpt_score_thr=0.3,
                     pose_kpt_color=None,
                     pose_link_color=None,
                     radius=4,
                     thickness=1,
                     show_keypoint_weight=False):
    """Draw keypoints and links on an image.

    Args:
        img (str or ndarray): The image to draw poses on. If an image array
            is given, it will be modified in-place.
        pose_result (list[kpts]): The poses to draw. Each element kpts is
            a set of K keypoints as an Kx3 numpy.ndarray, where each
            keypoint is represented as x, y, score.
        skeleton (list, optional): Links between keypoints, each given as a
            pair of keypoint indices. If None, the links will not be drawn.
        kpt_score_thr (float, optional): Minimum score of keypoints
            to be shown. Default: 0.3.
        pose_kpt_color (np.array[Nx3]): Color of N keypoints. If None,
            the keypoint will not be drawn.
        pose_link_color (np.array[Mx3]): Color of M links. If None, the
            links will not be drawn.
        radius (int): Radius of the keypoint circles. Default: 4.
        thickness (int): Thickness of lines.
        show_keypoint_weight (bool): If True, keypoints and links are
            alpha-blended onto the image with an opacity derived from the
            keypoint scores (clamped to [0, 1]). Default: False.

    Returns:
        ndarray: The image with keypoints and links drawn on it.
    """

    img = mmcv.imread(img)
    img_h, img_w, _ = img.shape

    for kpts in pose_result:

        # np.asarray only copies when it has to. The previous
        # np.array(kpts, copy=False) raises on NumPy >= 2.0 whenever a copy
        # is required (e.g. when kpts is a nested list).
        kpts = np.asarray(kpts)

        # draw each point on image
        if pose_kpt_color is not None:
            assert len(pose_kpt_color) == len(kpts)

            for kid, kpt in enumerate(kpts):
                kpt_score = kpt[2]

                if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
                    # skip the point that should not be drawn
                    continue

                # convert coordinates only for points that will be drawn, so
                # an invalid (e.g. NaN) location on a rejected keypoint
                # cannot raise
                center = (int(kpt[0]), int(kpt[1]))
                color = tuple(int(c) for c in pose_kpt_color[kid])
                if show_keypoint_weight:
                    # draw on a copy and blend it back, using the score as
                    # the opacity
                    img_copy = img.copy()
                    cv2.circle(img_copy, center, radius, color, -1)
                    transparency = max(0, min(1, kpt_score))
                    cv2.addWeighted(
                        img_copy,
                        transparency,
                        img,
                        1 - transparency,
                        0,
                        dst=img)
                else:
                    cv2.circle(img, center, radius, color, -1)

        # draw links
        if skeleton is not None and pose_link_color is not None:
            assert len(pose_link_color) == len(skeleton)

            for sk_id, sk in enumerate(skeleton):
                if (pose_link_color[sk_id] is None
                        or kpts[sk[0], 2] < kpt_score_thr
                        or kpts[sk[1], 2] < kpt_score_thr):
                    # skip the link that should not be drawn
                    continue

                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))

                if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0
                        or pos1[1] >= img_h or pos2[0] <= 0
                        or pos2[0] >= img_w or pos2[1] <= 0
                        or pos2[1] >= img_h):
                    # skip the link with an endpoint on or outside the
                    # image border
                    continue

                color = tuple(int(c) for c in pose_link_color[sk_id])
                if show_keypoint_weight:
                    img_copy = img.copy()
                    X = (pos1[0], pos2[0])
                    Y = (pos1[1], pos2[1])
                    mX = np.mean(X)
                    mY = np.mean(Y)
                    length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5
                    angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
                    stickwidth = 2
                    # the link is rendered as a thin filled ellipse between
                    # the two endpoints
                    polygon = cv2.ellipse2Poly(
                        (int(mX), int(mY)), (int(length / 2), int(stickwidth)),
                        int(angle), 0, 360, 1)
                    cv2.fillConvexPoly(img_copy, polygon, color)
                    # opacity is the mean score of the two endpoints
                    transparency = max(
                        0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2])))
                    cv2.addWeighted(
                        img_copy,
                        transparency,
                        img,
                        1 - transparency,
                        0,
                        dst=img)
                else:
                    cv2.line(img, pos1, pos2, color, thickness=thickness)

    return img
186+
187+
15188
@NODES.register_module()
16189
class ObjectVisualizerNode(BaseVisualizerNode):
17190
"""Visualize the bounding box and keypoints of objects.
@@ -91,12 +264,11 @@ def __init__(self,
91264
self.show_bbox = show_bbox
92265
self.show_keypoint = show_keypoint
93266
self.must_have_keypoint = must_have_keypoint
267+
self.radius = radius
268+
self.thickness = thickness
94269

95-
self.visualizer = PoseLocalVisualizer(
96-
name='webcam', radius=radius, line_width=thickness)
97-
98-
def draw(self, input_msg: FrameMessage) -> np.ndarray:
99-
canvas = input_msg.get_image()
270+
def _draw_bbox(self, canvas: np.ndarray, input_msg: FrameMessage):
271+
"""Draw object bboxes."""
100272

101273
if self.must_have_keypoint:
102274
objects = input_msg.get_objects(
@@ -107,49 +279,63 @@ def draw(self, input_msg: FrameMessage) -> np.ndarray:
107279
if not objects:
108280
return canvas
109281

110-
objects_by_label = defaultdict(list)
111-
for object in objects:
112-
objects_by_label[object['label']].append(object)
113-
114-
# draw objects of each category individually
115-
for label, objects in objects_by_label.items():
116-
dataset_meta = objects[0]['dataset_meta']
117-
dataset_meta['bbox_color'] = self.default_bbox_color.get(
118-
label, self.bbox_color)
119-
self.visualizer.set_dataset_meta(dataset_meta)
120-
121-
# assign bboxes, keypoints and other predictions to data_sample
122-
instances = InstanceData()
123-
instances.bboxes = np.stack([object['bbox'] for object in objects])
124-
instances.labels = np.array(
125-
[object['class_id'] for object in objects])
126-
if self.show_keypoint:
127-
keypoints = [
128-
object['keypoints'] for object in objects
129-
if 'keypoints' in object
130-
]
131-
if len(keypoints):
132-
instances.keypoints = np.stack(keypoints)
133-
keypoint_scores = [
134-
object['keypoint_scores'] for object in objects
135-
if 'keypoint_scores' in object
136-
]
137-
if len(keypoint_scores):
138-
instances.keypoint_scores = np.stack(keypoint_scores)
139-
data_sample = PoseDataSample()
140-
data_sample.pred_instances = instances
141-
142-
self.visualizer.add_datasample(
143-
'result',
282+
bboxes = [obj['bbox'] for obj in objects]
283+
labels = [obj.get('label', None) for obj in objects]
284+
default_color = (0, 255, 0)
285+
286+
# Get bbox colors
287+
if isinstance(self.bbox_color, dict):
288+
colors = [
289+
self.bbox_color.get(label, default_color) for label in labels
290+
]
291+
else:
292+
colors = self.bbox_color
293+
294+
imshow_bboxes(
295+
canvas,
296+
np.vstack(bboxes),
297+
labels=labels,
298+
colors=colors,
299+
text_color='white',
300+
font_scale=0.5)
301+
302+
return canvas
303+
304+
def _draw_keypoint(self, canvas: np.ndarray, input_msg: FrameMessage):
    """Draw object keypoints.

    Objects carrying a ``pose_model_cfg`` entry are grouped by that config,
    and each group is drawn with the skeleton and colors from its own
    ``dataset_meta``.

    Args:
        canvas (np.ndarray): The image to draw on.
        input_msg (FrameMessage): The frame message holding the objects.

    Returns:
        np.ndarray: The image with keypoints drawn on it.
    """
    objects = input_msg.get_objects(lambda x: 'pose_model_cfg' in x)

    # return if there is no object with keypoints
    if not objects:
        return canvas

    # groupby only merges consecutive objects with equal configs; objects of
    # the same model that are not adjacent simply form separate groups and
    # are drawn in separate calls.
    for _, group in groupby(objects, lambda x: x['pose_model_cfg']):
        # materialize the group so its first element can be read without
        # consuming the iterator
        group = list(group)
        # use the metadata of this group, not of the first object overall,
        # so objects from different pose models get the correct skeleton
        dataset_info = group[0]['dataset_meta']
        keypoints = [
            np.concatenate(
                (obj['keypoints'], obj['keypoint_scores'][:, None]), axis=1)
            for obj in group
        ]
        canvas = imshow_keypoints(
            canvas,
            keypoints,
            skeleton=dataset_info['skeleton_links'],
            kpt_score_thr=self.kpt_thr,
            pose_kpt_color=dataset_info['keypoint_colors'],
            pose_link_color=dataset_info['skeleton_link_colors'],
            radius=self.radius,
            thickness=self.thickness)

    return canvas
331+
332+
def draw(self, input_msg: FrameMessage) -> np.ndarray:
    """Render the enabled visualizations for one frame.

    The image is taken from ``input_msg``; bounding boxes are drawn first
    when ``show_bbox`` is set, then keypoints when ``show_keypoint`` is set.

    Args:
        input_msg (FrameMessage): The frame message to visualize.

    Returns:
        np.ndarray: The resulting image.
    """
    frame = input_msg.get_image()

    if self.show_bbox:
        frame = self._draw_bbox(frame, input_msg)

    if not self.show_keypoint:
        return frame

    return self._draw_keypoint(frame, input_msg)

0 commit comments

Comments
 (0)