Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,26 +130,27 @@ A summary can be found in the [Model Zoo](https://mmpose.readthedocs.io/en/0.x/m
<details open>
<summary><b>Supported algorithms:</b></summary>

- [x] [DeepPose](https://mmpose.readthedocs.io/en/0.x/papers/algorithms.html#deeppose-cvpr-2014) (CVPR'2014)
- [x] [CPM](https://mmpose.readthedocs.io/en/0.x/papers/backbones.html#cpm-cvpr-2016) (CVPR'2016)
- [x] [Hourglass](https://mmpose.readthedocs.io/en/0.x/papers/backbones.html#hourglass-eccv-2016) (ECCV'2016)
- [x] [SimpleBaseline3D](https://mmpose.readthedocs.io/en/0.x/papers/algorithms.html#simplebaseline3d-iccv-2017) (ICCV'2017)
- [x] [Associative Embedding](https://mmpose.readthedocs.io/en/0.x/papers/algorithms.html#associative-embedding-nips-2017) (NeurIPS'2017)
- [x] [HMR](https://mmpose.readthedocs.io/en/0.x/papers/algorithms.html#hmr-cvpr-2018) (CVPR'2018)
- [x] [SimpleBaseline2D](https://mmpose.readthedocs.io/en/0.x/papers/algorithms.html#simplebaseline2d-eccv-2018) (ECCV'2018)
- [x] [HRNet](https://mmpose.readthedocs.io/en/0.x/papers/backbones.html#hrnet-cvpr-2019) (CVPR'2019)
- [x] [VideoPose3D](https://mmpose.readthedocs.io/en/0.x/papers/algorithms.html#videopose3d-cvpr-2019) (CVPR'2019)
- [x] [HRNetv2](https://mmpose.readthedocs.io/en/0.x/papers/backbones.html#hrnetv2-tpami-2019) (TPAMI'2019)
- [x] [MSPN](https://mmpose.readthedocs.io/en/0.x/papers/backbones.html#mspn-arxiv-2019) (ArXiv'2019)
- [x] [SCNet](https://mmpose.readthedocs.io/en/0.x/papers/backbones.html#scnet-cvpr-2020) (CVPR'2020)
- [x] [HigherHRNet](https://mmpose.readthedocs.io/en/0.x/papers/backbones.html#higherhrnet-cvpr-2020) (CVPR'2020)
- [x] [RSN](https://mmpose.readthedocs.io/en/0.x/papers/backbones.html#rsn-eccv-2020) (ECCV'2020)
- [x] [InterNet](https://mmpose.readthedocs.io/en/0.x/papers/algorithms.html#internet-eccv-2020) (ECCV'2020)
- [x] [VoxelPose](https://mmpose.readthedocs.io/en/0.x/papers/algorithms.html#voxelpose-eccv-2020) (ECCV'2020)
- [x] [LiteHRNet](https://mmpose.readthedocs.io/en/0.x/papers/backbones.html#litehrnet-cvpr-2021) (CVPR'2021)
- [x] [ViPNAS](https://mmpose.readthedocs.io/en/0.x/papers/backbones.html#vipnas-cvpr-2021) (CVPR'2021)
- [x] [DEKR](https://mmpose.readthedocs.io/en/0.x/papers/algorithms.html#dekr-cvpr-2021) (CVPR'2021)
- [x] [CID](https://mmpose.readthedocs.io/en/0.x/papers/algorithms.html#cid-cvpr-2022) (CVPR'2022)
- [x] [DeepPose](https://mmpose.readthedocs.io/en/latest/papers/algorithms.html#deeppose-cvpr-2014) (CVPR'2014)
- [x] [CPM](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#cpm-cvpr-2016) (CVPR'2016)
- [x] [Hourglass](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hourglass-eccv-2016) (ECCV'2016)
- [x] [SimpleBaseline3D](https://mmpose.readthedocs.io/en/latest/papers/algorithms.html#simplebaseline3d-iccv-2017) (ICCV'2017)
- [x] [Associative Embedding](https://mmpose.readthedocs.io/en/latest/papers/algorithms.html#associative-embedding-nips-2017) (NeurIPS'2017)
- [x] [HMR](https://mmpose.readthedocs.io/en/latest/papers/algorithms.html#hmr-cvpr-2018) (CVPR'2018)
- [x] [SimpleBaseline2D](https://mmpose.readthedocs.io/en/latest/papers/algorithms.html#simplebaseline2d-eccv-2018) (ECCV'2018)
- [x] [HRNet](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hrnet-cvpr-2019) (CVPR'2019)
- [x] [VideoPose3D](https://mmpose.readthedocs.io/en/latest/papers/algorithms.html#videopose3d-cvpr-2019) (CVPR'2019)
- [x] [HRNetv2](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hrnetv2-tpami-2019) (TPAMI'2019)
- [x] [MSPN](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) (ArXiv'2019)
- [x] [SCNet](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#scnet-cvpr-2020) (CVPR'2020)
- [x] [HigherHRNet](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#higherhrnet-cvpr-2020) (CVPR'2020)
- [x] [RSN](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#rsn-eccv-2020) (ECCV'2020)
- [x] [InterNet](https://mmpose.readthedocs.io/en/latest/papers/algorithms.html#internet-eccv-2020) (ECCV'2020)
- [x] [VoxelPose](https://mmpose.readthedocs.io/en/latest/papers/algorithms.html#voxelpose-eccv-2020) (ECCV'2020)
- [x] [LiteHRNet](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) (CVPR'2021)
- [x] [ViPNAS](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#vipnas-cvpr-2021) (CVPR'2021)
- [x] [DEKR](https://mmpose.readthedocs.io/zh_CN/latest/papers/algorithms.html#dekr-cvpr-2021) (CVPR'2021)
- [x] [CID](https://mmpose.readthedocs.io/zh_CN/latest/papers/algorithms.html#cid-cvpr-2022) (CVPR'2022)
- [x] [ViTPose](https://mmpose.readthedocs.io/en/latest/papers/algorithms.html#vitpose-neurips-2022) (Neurips'2022)

</details>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
_base_ = [
'../../../../_base_/default_runtime.py',
'../../../../_base_/datasets/coco.py'
]
evaluation = dict(interval=10, metric='mAP', save_best='AP')

optimizer = dict(
type='AdamW',
lr=5e-4,
betas=(0.9, 0.999),
weight_decay=0.1,
constructor='LayerDecayOptimizerConstructor',
paramwise_cfg=dict(
num_layers=12,
layer_decay_rate=0.75,
))

optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2))

# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[170, 200])
total_epochs = 210
target_type = 'GaussianHeatmap'
channel_cfg = dict(
num_output_channels=17,
dataset_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])

# model settings
model = dict(
type='TopDown',
pretrained=None,
backbone=dict(
type='VisionTransformer',
img_size=(256, 192),
patch_size=16,
embed_dims=768,
# Optional in train
padding=2,
num_layers=12,
num_heads=12,
mlp_ratio=4,
drop_path_rate=0.3,
final_norm=True,
),
keypoint_head=dict(
type='TopdownHeatmapSimpleHead',
in_channels=768,
num_deconv_layers=2,
num_deconv_filters=(256, 256),
num_deconv_kernels=(4, 4),
extra=dict(final_conv_kernel=1, ),
out_channels=channel_cfg['num_output_channels'],
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
train_cfg=dict(),
test_cfg=dict(
flip_test=True,
post_process='default',
shift_heatmap=False,
target_type=target_type,
modulate_kernel=11,
use_udp=True))

data_cfg = dict(
image_size=[192, 256],
heatmap_size=[48, 64],
num_output_channels=channel_cfg['num_output_channels'],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'],
soft_nms=False,
nms_thr=1.0,
oks_thr=0.9,
vis_thr=0.2,
use_gt_bbox=False,
det_bbox_thr=0.0,
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownGetBboxCenterScale', padding=1.25),
dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
dict(type='TopDownRandomFlip', flip_prob=0.5),
dict(
type='TopDownHalfBodyTransform',
num_joints_half_body=8,
prob_half_body=0.3),
dict(
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='TopDownGenerateTarget',
sigma=2,
encoding='UDP',
target_type=target_type),
dict(
type='Collect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
'rotation', 'bbox_score', 'flip_pairs'
]),
]

val_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownGetBboxCenterScale', padding=1.25),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
'flip_pairs'
]),
]

test_pipeline = val_pipeline

data_root = 'data/coco'
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=32),
test_dataloader=dict(samples_per_gpu=32),
train=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
img_prefix=f'{data_root}/train2017/',
data_cfg=data_cfg,
pipeline=train_pipeline,
dataset_info={{_base_.dataset_info}}),
val=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=val_pipeline,
dataset_info={{_base_.dataset_info}}),
test=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=test_pipeline,
dataset_info={{_base_.dataset_info}}),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2204.12484">ViTPose (Neurips'2022)</a></summary>

```bibtex
@inproceedings{
xu2022vitpose,
title={Vi{TP}ose: Simple Vision Transformer Baselines for Human Pose Estimation},
author={Yufei Xu and Jing Zhang and Qiming Zhang and Dacheng Tao},
booktitle={Advances in Neural Information Processing Systems},
year={2022},
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://link.springer.com/chapter/10.1007/978-3-319-10602-1_48">COCO (ECCV'2014)</a></summary>

```bibtex
@inproceedings{lin2014microsoft,
title={Microsoft coco: Common objects in context},
author={Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence},
booktitle={European conference on computer vision},
pages={740--755},
year={2014},
organization={Springer}
}
```

</details>

The backbone models are pre-trained using MAE. The small-size pre-trained backbone can be found in [link](https://github.com/ViTAE-Transformer/ViTPose). The base, large, and huge pre-trained backbones can be found in [link](https://github.com/facebookresearch/mae).

Results on COCO val2017 with detector having human AP of 56.4 on COCO val2017 dataset

| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt | log |
| :--------------------------------------------------------------------------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :--------: | :-------: |
| [ViTPose-S](/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vitpose_small_coco_256x192.py) | 256x192 | 0.738 | 0.903 | 0.813 | 0.792 | 0.940 | [ckpt](<>) | [log](<>) |
| [ViTPose-B](/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vitpose_base_coco_256x192.py) | 256x192 | 0.758 | 0.907 | 0.832 | 0.811 | 0.946 | [ckpt](<>) | [log](<>) |
| [ViTPose-L](/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vitpose_large_coco_256x192.py) | 256x192 | 0.783 | 0.914 | 0.852 | 0.835 | 0.953 | [ckpt](<>) | [log](<>) |
| [ViTPose-H](/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vitpose_huge_coco_256x192.py) | 256x192 | 0.791 | 0.917 | 0.857 | 0.841 | 0.954 | [ckpt](<>) | [log](<>) |
| [ViTPose-Simple-S](/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vitpose_simple_small_coco_256x192.py) | 256x192 | 0.735 | 0.900 | 0.811 | 0.789 | 0.940 | [ckpt](<>) | [log](<>) |
| [ViTPose-Simple-B](/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vitpose_simple_base_coco_256x192.py) | 256x192 | 0.755 | 0.906 | 0.829 | 0.809 | 0.946 | [ckpt](<>) | [log](<>) |
| [ViTPose-Simple-L](/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vitpose_simple_large_coco_256x192.py) | 256x192 | 0.782 | 0.914 | 0.853 | 0.834 | 0.953 | [ckpt](<>) | [log](<>) |
| [ViTPose-Simple-H](/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vitpose_simple_huge_coco_256x192.py) | 256x192 | 0.789 | 0.916 | 0.856 | 0.840 | 0.954 | [ckpt](<>) | [log](<>) |
Loading