
Commit f42ec26

Merge 304f011 into 3eca326
2 parents 3eca326 + 304f011 commit f42ec26

13 files changed: +591 / -6 lines changed

docs/en/user_guides/train_and_test.md

Lines changed: 118 additions & 0 deletions

@@ -390,3 +390,121 @@ Here are the environment variables that can be used to configure the slurm job.

| `GPUS_PER_NODE` | The number of GPUs to be allocated per node. Defaults to 8. |
| `CPUS_PER_TASK` | The number of CPUs to be allocated per task (usually one GPU corresponds to one task). Defaults to 5. |
| `SRUN_ARGS` | The other arguments of `srun`. Available options can be found [here](https://slurm.schedmd.com/srun.html). |

## Custom Testing Features

### Test with Custom Metrics

If you want to evaluate models with metrics that MMPose does not yet support, you will need to implement them yourself and include them in your config file. For guidance on how to do this, check out the [customized evaluation guide](https://mmpose.readthedocs.io/en/latest/advanced_guides/customize_evaluation.html).
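As a rough illustration (not part of the original guide), a custom metric usually subclasses `mmengine`'s `BaseMetric` and registers itself with MMPose's `METRICS` registry. The class name `MeanKeypointError`, the distance computation, and the sample keys used below are assumptions made for this sketch only:

```python
from typing import Sequence

import numpy as np
from mmengine.evaluator import BaseMetric

from mmpose.registry import METRICS


@METRICS.register_module()
class MeanKeypointError(BaseMetric):
    """Toy metric: mean Euclidean distance between predicted and GT keypoints."""

    def process(self, data_batch, data_samples: Sequence[dict]) -> None:
        # Accumulate one scalar per data sample; layout follows MMPose's
        # built-in metrics (pred_instances / gt_instances with 'keypoints').
        for sample in data_samples:
            pred = sample['pred_instances']['keypoints']  # (N, K, D)
            gt = sample['gt_instances']['keypoints']      # (N, K, D)
            self.results.append(np.linalg.norm(pred - gt, axis=-1).mean())

    def compute_metrics(self, results: list) -> dict:
        # Aggregate the per-sample results collected in `process`.
        return dict(mean_keypoint_error=float(np.mean(results)))
```

Once the module defining such a metric is imported (for example via `custom_imports` in the config), it could be referenced like any built-in metric, e.g. `val_evaluator = dict(type='MeanKeypointError')`.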
### Evaluating Across Multiple Datasets

MMPose offers a handy tool known as `MultiDatasetEvaluator` for streamlined assessment across multiple datasets. Setting up this evaluator in your config file is straightforward. Below is a quick example demonstrating how to evaluate a model using both the COCO and AIC datasets:

```python
# set up validation datasets
coco_val = dict(type='CocoDataset', ...)
aic_val = dict(type='AicDataset', ...)
val_dataset = dict(
    type='CombinedDataset',
    datasets=[coco_val, aic_val],
    pipeline=val_pipeline,
    ...)

# configure the evaluator
val_evaluator = dict(
    type='MultiDatasetEvaluator',
    metrics=[  # metrics for each dataset
        dict(type='CocoMetric',
             ann_file='data/coco/annotations/person_keypoints_val2017.json'),
        dict(type='CocoMetric',
             ann_file='data/aic/annotations/aic_val.json',
             use_area=False,
             prefix='aic')
    ],
    # the number and order of datasets must align with metrics
    datasets=[coco_val, aic_val],
)
```

Keep in mind that different datasets, like COCO and AIC, have different keypoint definitions, while the model's output keypoints follow a single standardized order. This creates a mismatch between the model outputs and the ground truth. To address it, you can employ `KeypointConverter` to align the keypoint configurations across datasets. Here is a full example that shows how to leverage `KeypointConverter` to align AIC keypoints with COCO keypoints:

```python
aic_to_coco_converter = dict(
    type='KeypointConverter',
    num_keypoints=17,
    mapping=[
        (0, 6),
        (1, 8),
        (2, 10),
        (3, 5),
        (4, 7),
        (5, 9),
        (6, 12),
        (7, 14),
        (8, 16),
        (9, 11),
        (10, 13),
        (11, 15),
    ])

# val datasets
coco_val = dict(
    type='CocoDataset',
    data_root='data/coco/',
    data_mode='topdown',
    ann_file='annotations/person_keypoints_val2017.json',
    bbox_file='data/coco/person_detection_results/'
    'COCO_val2017_detections_AP_H_56_person.json',
    data_prefix=dict(img='val2017/'),
    test_mode=True,
    pipeline=[],
)

aic_val = dict(
    type='AicDataset',
    data_root='data/aic/',
    data_mode=data_mode,
    ann_file='annotations/aic_val.json',
    data_prefix=dict(img='ai_challenger_keypoint_validation_20170911/'
                     'keypoint_validation_images_20170911/'),
    test_mode=True,
    pipeline=[],
)

val_dataset = dict(
    type='CombinedDataset',
    metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
    datasets=[coco_val, aic_val],
    pipeline=val_pipeline,
    test_mode=True,
)

val_dataloader = dict(
    batch_size=32,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=val_dataset)

test_dataloader = val_dataloader

val_evaluator = dict(
    type='MultiDatasetEvaluator',
    metrics=[
        dict(type='CocoMetric',
             ann_file=data_root + 'annotations/person_keypoints_val2017.json'),
        dict(type='CocoMetric',
             ann_file='data/aic/annotations/aic_val.json',
             use_area=False,
             gt_converter=aic_to_coco_converter,
             prefix='aic')
    ],
    datasets=val_dataset['datasets'],
)

test_evaluator = val_evaluator
```

For further clarification on converting AIC keypoints to COCO keypoints, please consult [this guide](https://mmpose.readthedocs.io/en/latest/user_guides/mixed_datasets.html#merge-aic-into-coco).

docs/zh_cn/user_guides/train_and_test.md

Lines changed: 120 additions & 0 deletions

@@ -376,3 +376,123 @@ NNODES=2 NODE_RANK=1 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR bash tools/dist_

| `GPUS_PER_NODE` | The total number of GPUs used per node. Defaults to 8. |
| `CPUS_PER_TASK` | The total number of CPUs allocated per task (usually one GPU corresponds to one task process). Defaults to 5. |
| `SRUN_ARGS` | Other arguments of `srun`. Available options can be found [here](https://slurm.schedmd.com/srun.html). |

## Custom Testing

### Test with Custom Metrics

If you want to evaluate models with metrics that MMPose does not yet support, you will need to implement them yourself and include them in your config file. For guidance on how to do this, please see the [customized evaluation guide](https://mmpose.readthedocs.io/zh_CN/dev-1.x/advanced_guides/customize_evaluation.html).

### Evaluating Across Multiple Datasets

MMPose provides a convenient tool named `MultiDatasetEvaluator` for simplified evaluation across multiple datasets. Setting up this evaluator in the config file is straightforward. Below is a quick example demonstrating how to evaluate a model using the COCO and AIC datasets:

```python
# set up validation datasets
coco_val = dict(type='CocoDataset', ...)

aic_val = dict(type='AicDataset', ...)

val_dataset = dict(
    type='CombinedDataset',
    datasets=[coco_val, aic_val],
    pipeline=val_pipeline,
    ...)

# configure the evaluator
val_evaluator = dict(
    type='MultiDatasetEvaluator',
    metrics=[  # metrics for each dataset
        dict(type='CocoMetric',
             ann_file='data/coco/annotations/person_keypoints_val2017.json'),
        dict(type='CocoMetric',
             ann_file='data/aic/annotations/aic_val.json',
             use_area=False,
             prefix='aic')
    ],
    # the number and order of datasets must match the metrics
    datasets=[coco_val, aic_val],
)
```

Different datasets, such as COCO and AIC, have different keypoint definitions, whereas the model's output keypoints are standardized. This leads to a mismatch in keypoint order between the model outputs and the ground truth. To resolve this, you can use `KeypointConverter` to align the keypoint order across datasets. Below is a complete example showing how to use `KeypointConverter` to align AIC keypoints with COCO keypoints:

```python
aic_to_coco_converter = dict(
    type='KeypointConverter',
    num_keypoints=17,
    mapping=[
        (0, 6),
        (1, 8),
        (2, 10),
        (3, 5),
        (4, 7),
        (5, 9),
        (6, 12),
        (7, 14),
        (8, 16),
        (9, 11),
        (10, 13),
        (11, 15),
    ])

# val datasets
coco_val = dict(
    type='CocoDataset',
    data_root='data/coco/',
    data_mode='topdown',
    ann_file='annotations/person_keypoints_val2017.json',
    bbox_file='data/coco/person_detection_results/'
    'COCO_val2017_detections_AP_H_56_person.json',
    data_prefix=dict(img='val2017/'),
    test_mode=True,
    pipeline=[],
)

aic_val = dict(
    type='AicDataset',
    data_root='data/aic/',
    data_mode=data_mode,
    ann_file='annotations/aic_val.json',
    data_prefix=dict(img='ai_challenger_keypoint_validation_20170911/'
                     'keypoint_validation_images_20170911/'),
    test_mode=True,
    pipeline=[],
)

val_dataset = dict(
    type='CombinedDataset',
    metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
    datasets=[coco_val, aic_val],
    pipeline=val_pipeline,
    test_mode=True,
)

val_dataloader = dict(
    batch_size=32,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=val_dataset)

test_dataloader = val_dataloader

val_evaluator = dict(
    type='MultiDatasetEvaluator',
    metrics=[
        dict(type='CocoMetric',
             ann_file=data_root + 'annotations/person_keypoints_val2017.json'),
        dict(type='CocoMetric',
             ann_file='data/aic/annotations/aic_val.json',
             use_area=False,
             gt_converter=aic_to_coco_converter,
             prefix='aic')
    ],
    datasets=val_dataset['datasets'],
)

test_evaluator = val_evaluator
```

For more details on how to convert AIC keypoints to COCO keypoints, please consult [this guide](https://mmpose.readthedocs.io/zh_CN/dev-1.x/user_guides/mixed_datasets.html#aic-coco).

mmpose/datasets/datasets/base/base_coco_style_dataset.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -166,7 +166,7 @@ def get_data_info(self, idx: int) -> dict:
 
         # Add metainfo items that are required in the pipeline and the model
         metainfo_keys = [
-            'upper_body_ids', 'lower_body_ids', 'flip_pairs',
+            'dataset_name', 'upper_body_ids', 'lower_body_ids', 'flip_pairs',
             'dataset_keypoint_weights', 'flip_indices', 'skeleton_links'
         ]
 
```

mmpose/datasets/transforms/converting.py

Lines changed: 39 additions & 1 deletion

```diff
@@ -83,7 +83,7 @@ def __init__(self, num_keypoints: int,
             self.source_index2 = src2
 
         self.source_index = src1
-        self.target_index = target_index
+        self.target_index = list(target_index)
         self.interpolation = interpolation
 
     def transform(self, results: dict) -> dict:
@@ -122,6 +122,44 @@ def transform(self, results: dict) -> dict:
                 [keypoints_visible, keypoints_visible_weights], axis=2)
         return results
 
+    def transform_sigmas(self, sigmas: Union[List, np.ndarray]):
+        """Transforms the sigmas based on the mapping."""
+        list_input = False
+        if isinstance(sigmas, list):
+            sigmas = np.array(sigmas)
+            list_input = True
+
+        new_sigmas = np.ones(self.num_keypoints, dtype=sigmas.dtype)
+        new_sigmas[self.target_index] = sigmas[self.source_index]
+
+        if list_input:
+            new_sigmas = new_sigmas.tolist()
+
+        return new_sigmas
+
+    def transform_ann(self, ann_info: Union[dict, list]):
+        """Transforms the annotations based on the mapping."""
+
+        list_input = True
+        if not isinstance(ann_info, list):
+            ann_info = [ann_info]
+            list_input = False
+
+        for ann in ann_info:
+            if 'keypoints' in ann:
+                keypoints = np.array(ann['keypoints']).reshape(-1, 3)
+                new_keypoints = np.zeros((self.num_keypoints, 3),
+                                         dtype=keypoints.dtype)
+                new_keypoints[self.target_index] = keypoints[self.source_index]
+                ann['keypoints'] = new_keypoints.reshape(-1).tolist()
+            if 'num_keypoints' in ann:
+                ann['num_keypoints'] = self.num_keypoints
+
+        if not list_input:
+            ann_info = ann_info[0]
+
+        return ann_info
+
     def __repr__(self) -> str:
         """print the basic information of the transform.
```
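As a quick, hedged illustration of what these new helper methods do, here is a usage sketch; the sigma values and the annotation dict are made-up placeholders, while the mapping is the AIC-to-COCO mapping shown in the docs above:

```python
import numpy as np

from mmpose.datasets.transforms import KeypointConverter

# AIC (14 keypoints) -> COCO (17 keypoints) mapping, as in the docs above
converter = KeypointConverter(
    num_keypoints=17,
    mapping=[(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9),
             (6, 12), (7, 14), (8, 16), (9, 11), (10, 13), (11, 15)])

# transform_sigmas: remap per-keypoint sigmas into the target keypoint order;
# unmapped target keypoints keep the default sigma of 1.0
aic_sigmas = [0.05] * 14  # placeholder values, one per AIC keypoint
coco_like_sigmas = converter.transform_sigmas(aic_sigmas)
assert len(coco_like_sigmas) == 17

# transform_ann: remap a COCO-style annotation dict to the target order
ann = dict(keypoints=np.arange(14 * 3).tolist(), num_keypoints=14)  # dummy data
ann = converter.transform_ann(ann)
assert len(ann['keypoints']) == 17 * 3 and ann['num_keypoints'] == 17
```

This is the same remapping that `CocoMetric` applies to the ground truth when a `gt_converter` is configured, as in the `aic_to_coco_converter` example above.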

mmpose/datasets/transforms/formatting.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -148,7 +148,7 @@ def __init__(self,
                             'crowd_index', 'ori_shape', 'img_shape',
                             'input_size', 'input_center', 'input_scale',
                             'flip', 'flip_direction', 'flip_indices',
-                            'raw_ann_info'),
+                            'raw_ann_info', 'dataset_name'),
                  pack_transformed=False):
         self.meta_keys = meta_keys
         self.pack_transformed = pack_transformed
```

mmpose/evaluation/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,3 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .evaluators import *  # noqa: F401,F403
 from .functional import *  # noqa: F401,F403
 from .metrics import *  # noqa: F401,F403
```
mmpose/evaluation/evaluators/__init__.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .mutli_dataset_evaluator import MultiDatasetEvaluator
+
+__all__ = ['MultiDatasetEvaluator']
```
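Taken together, the non-documentation changes thread a `dataset_name` metainfo key from the dataset, through `PackPoseInputs`, into every packed data sample, presumably so that a multi-dataset evaluator can tell samples from different datasets apart. Below is a minimal, hypothetical sketch of that routing idea using plain dicts; it is not the actual `MultiDatasetEvaluator` implementation:

```python
from collections import defaultdict


def group_by_dataset(data_samples):
    """Hypothetical helper: bucket samples by the 'dataset_name' key
    that the changes above propagate into each data sample."""
    buckets = defaultdict(list)
    for sample in data_samples:
        buckets[sample['dataset_name']].append(sample)
    return buckets


# Toy input standing in for packed data samples
samples = [
    {'dataset_name': 'coco', 'img_id': 1},
    {'dataset_name': 'aic', 'img_id': 7},
    {'dataset_name': 'coco', 'img_id': 2},
]
grouped = group_by_dataset(samples)
# grouped['coco'] would go to the COCO CocoMetric, grouped['aic'] to the
# AIC CocoMetric configured with gt_converter=aic_to_coco_converter.
print({name: len(group) for name, group in grouped.items()})
```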
