@@ -13,6 +13,7 @@ def setUp(self):
1313 img_shape = (240 , 320 ), num_instances = 4 , with_bbox_cs = True )
1414
1515 def test_transform (self ):
16+ # 1-to-1 mapping
1617 mapping = [(3 , 0 ), (6 , 1 ), (16 , 2 ), (5 , 3 )]
1718 transform = KeypointConverter (num_keypoints = 5 , mapping = mapping )
1819 results = transform (self .data_info .copy ())
@@ -34,3 +35,39 @@ def test_transform(self):
3435 self .assertTrue (
3536 (results ['keypoints_visible' ][:, target_index ] ==
3637 self .data_info ['keypoints_visible' ][:, source_index ]).all ())
38+
39+ # 2-to-1 mapping
40+ mapping = [((3 , 5 ), 0 ), (6 , 1 ), (16 , 2 ), (5 , 3 )]
41+ transform = KeypointConverter (num_keypoints = 5 , mapping = mapping )
42+ results = transform (self .data_info .copy ())
43+
44+ # check shape
45+ self .assertEqual (results ['keypoints' ].shape [0 ],
46+ self .data_info ['keypoints' ].shape [0 ])
47+ self .assertEqual (results ['keypoints' ].shape [1 ], 5 )
48+ self .assertEqual (results ['keypoints' ].shape [2 ], 2 )
49+ self .assertEqual (results ['keypoints_visible' ].shape [0 ],
50+ self .data_info ['keypoints_visible' ].shape [0 ])
51+ self .assertEqual (results ['keypoints_visible' ].shape [1 ], 5 )
52+
53+ # check value
54+ for source_index , target_index in mapping :
55+ if isinstance (source_index , tuple ):
56+ source_index , source_index2 = source_index
57+ self .assertTrue (
58+ (results ['keypoints' ][:, target_index ] == 0.5 *
59+ (self .data_info ['keypoints' ][:, source_index ] +
60+ self .data_info ['keypoints' ][:, source_index2 ])).all ())
61+ self .assertTrue (
62+ (results ['keypoints_visible' ][:, target_index ] ==
63+ self .data_info ['keypoints_visible' ][:, source_index ] *
64+ self .data_info ['keypoints_visible' ][:,
65+ source_index2 ]).all ())
66+ else :
67+ self .assertTrue (
68+ (results ['keypoints' ][:, target_index ] ==
69+ self .data_info ['keypoints' ][:, source_index ]).all ())
70+ self .assertTrue (
71+ (results ['keypoints_visible' ][:, target_index ] ==
72+ self .data_info ['keypoints_visible' ][:,
73+ source_index ]).all ())
0 commit comments