Skip to content

Commit 0cdebc1

Browse files
committed
add docsting, ut
1 parent 4facefb commit 0cdebc1

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

mmpose/datasets/transforms/converting.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,36 @@ class KeypointConverter(BaseTransform):
2525
num_keypoints (int): The number of keypoints in target dataset.
2626
mapping (list): A list containing mapping indexes. Each element has
2727
format (source_index, target_index)
28+
29+
Example:
30+
>>> import numpy as np
31+
>>> # case 1: 1-to-1 mapping
32+
>>> # (0, 0) means target[0] = source[0]
33+
>>> self = KeypointConverter(
34+
>>> num_keypoints=3,
35+
>>> mapping=[
36+
>>> (0, 0), (1, 1), (2, 2), (3, 3)
37+
>>> ])
38+
>>> results = dict(
39+
>>> keypoints=np.arange(34).reshape(2, 3, 2),
40+
>>> keypoints_visible=np.arange(34).reshape(2, 3, 2) % 2)
41+
>>> results = self(results)
42+
>>> assert np.equal(results['keypoints'],
43+
>>> np.arange(34).reshape(2, 3, 2)).all()
44+
>>> assert np.equal(results['keypoints_visible'],
45+
>>> np.arange(34).reshape(2, 3, 2) % 2).all()
46+
>>>
47+
>>> # case 2: 2-to-1 mapping
48+
>>> # ((1, 2), 0) means target[0] = (source[1] + source[2]) / 2
49+
>>> self = KeypointConverter(
50+
>>> num_keypoints=3,
51+
>>> mapping=[
52+
>>> ((1, 2), 0), (1, 1), (2, 2)
53+
>>> ])
54+
>>> results = dict(
55+
>>> keypoints=np.arange(34).reshape(2, 3, 2),
56+
>>> keypoints_visible=np.arange(34).reshape(2, 3, 2) % 2)
57+
>>> results = self(results)
2858
"""
2959

3060
def __init__(self, num_keypoints: int,

tests/test_datasets/test_transforms/test_converting.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def setUp(self):
1313
img_shape=(240, 320), num_instances=4, with_bbox_cs=True)
1414

1515
def test_transform(self):
16+
# 1-to-1 mapping
1617
mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
1718
transform = KeypointConverter(num_keypoints=5, mapping=mapping)
1819
results = transform(self.data_info.copy())
@@ -34,3 +35,39 @@ def test_transform(self):
3435
self.assertTrue(
3536
(results['keypoints_visible'][:, target_index] ==
3637
self.data_info['keypoints_visible'][:, source_index]).all())
38+
39+
# 2-to-1 mapping
40+
mapping = [((3, 5), 0), (6, 1), (16, 2), (5, 3)]
41+
transform = KeypointConverter(num_keypoints=5, mapping=mapping)
42+
results = transform(self.data_info.copy())
43+
44+
# check shape
45+
self.assertEqual(results['keypoints'].shape[0],
46+
self.data_info['keypoints'].shape[0])
47+
self.assertEqual(results['keypoints'].shape[1], 5)
48+
self.assertEqual(results['keypoints'].shape[2], 2)
49+
self.assertEqual(results['keypoints_visible'].shape[0],
50+
self.data_info['keypoints_visible'].shape[0])
51+
self.assertEqual(results['keypoints_visible'].shape[1], 5)
52+
53+
# check value
54+
for source_index, target_index in mapping:
55+
if isinstance(source_index, tuple):
56+
source_index, source_index2 = source_index
57+
self.assertTrue(
58+
(results['keypoints'][:, target_index] == 0.5 *
59+
(self.data_info['keypoints'][:, source_index] +
60+
self.data_info['keypoints'][:, source_index2])).all())
61+
self.assertTrue(
62+
(results['keypoints_visible'][:, target_index] ==
63+
self.data_info['keypoints_visible'][:, source_index] *
64+
self.data_info['keypoints_visible'][:,
65+
source_index2]).all())
66+
else:
67+
self.assertTrue(
68+
(results['keypoints'][:, target_index] ==
69+
self.data_info['keypoints'][:, source_index]).all())
70+
self.assertTrue(
71+
(results['keypoints_visible'][:, target_index] ==
72+
self.data_info['keypoints_visible'][:,
73+
source_index]).all())

0 commit comments

Comments
 (0)