Commit de58466

[Feature] Support YOLO-Pose (#2020)

1 parent 45486ea commit de58466

21 files changed: +1765 −13 lines

demo/inferencer_demo.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -37,6 +37,11 @@ def parse_args():
         nargs='+',
         default=None,
         help='Category id for detection model.')
+    parser.add_argument(
+        '--scope',
+        type=str,
+        default='mmpose',
+        help='Scope where modules are defined.')
     parser.add_argument(
         '--device',
         type=str,
@@ -83,8 +88,8 @@ def parse_args():
     call_args = vars(parser.parse_args())
 
     init_kws = [
-        'pose2d', 'pose2d_weights', 'device', 'det_model', 'det_weights',
-        'det_cat_ids'
+        'pose2d', 'pose2d_weights', 'scope', 'device', 'det_model',
+        'det_weights', 'det_cat_ids'
     ]
     init_args = {}
     for init_kw in init_kws:
```
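For context, here is a minimal sketch (not part of this diff) of what the updated argument split does: every key listed in `init_kws` is popped out of the parsed arguments and handed to the inferencer's constructor, so the new `--scope` flag now reaches the inferencer at construction time. The values below are hypothetical.

```python
# Minimal sketch of the init/call argument split (hypothetical values).
call_args = {
    'inputs': 'tests/data/coco/000000000785.jpg',  # hypothetical input
    'pose2d': 'configs/yolox-pose_s_8xb32-300e_coco.py',
    'pose2d_weights': None,
    'scope': 'mmyolo',  # the newly added flag
    'device': None,
    'det_model': None,
    'det_weights': None,
    'det_cat_ids': None,
    'show': False,
}
init_kws = [
    'pose2d', 'pose2d_weights', 'scope', 'device', 'det_model',
    'det_weights', 'det_cat_ids'
]
# Same logic as the demo script: constructor kwargs are popped out of
# call_args, leaving only per-call options (inputs, show, ...) behind.
init_args = {kw: call_args.pop(kw) for kw in init_kws}
```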

projects/README.md

Lines changed: 13 additions & 11 deletions

```diff
@@ -30,20 +30,22 @@ We also provide some documentation listed below to help you get started:
 
 ## Project List
 
-- [MMPose4AIGC](./mmpose4aigc)
+- **[:zap:RTMPose](./rtmpose)**: Real-Time Multi-Person Pose Estimation toolkit based on MMPose
 
-  This project will demonstrate how to use MMPose to generate skeleton images for pose guided AI image generation.
+  <div align="center">
+  <img src="https://user-images.githubusercontent.com/15977946/225229448-36ff568d-a723-4248-bb19-2df4044ff8e8.png" width=800 height=200 />
+  </div><br/>
 
-  <div align=center>
-  <img src="https://user-images.githubusercontent.com/13503330/222403836-c65ba905-4bdd-4a44-834c-ff8d5959649d.png" width=1000 height=200/>
-  </div>
+- **[:art:MMPose4AIGC](./mmpose4aigc)**: Guide AI image generation with MMPose
 
-- [RTMPose](./rtmpose)
+  <div align=center>
+  <img src="https://user-images.githubusercontent.com/13503330/222403836-c65ba905-4bdd-4a44-834c-ff8d5959649d.png" width="800"/>
+  </div><br/>
 
-  Real-Time Multi-Person Pose Estimation toolkit based on MMPose
+- **[:bulb:YOLOX-Pose](./yolox-pose)**: Enhancing YOLO for Multi Person Pose Estimation Using Object Keypoint Similarity Loss
 
-  <div align="center">
-  <img width=1000 height=200 src="https://user-images.githubusercontent.com/15977946/225229448-36ff568d-a723-4248-bb19-2df4044ff8e8.png"/>
-  </div>
+  <div align=center>
+  <img src="https://user-images.githubusercontent.com/26127467/226655503-3cee746e-6e42-40be-82ae-6e7cae2a4c7e.jpg" width="800" style="width: 800px; height: 200px; object-fit: cover"/>
+  </div><br/>
 
-- **And we can't wait to see what you contribute next!**
+- **What's next? Join the rank of <span style="color:blue"> *MMPose contributors* </span> by creating a new project**!
```

projects/yolox-pose/README.md

Lines changed: 124 additions & 0 deletions (new file)

# YOLOX-Pose

This project implements a YOLOX-based human pose estimator, following the approach described in **YOLO-Pose: Enhancing YOLO for Multi Person Pose Estimation Using Object Keypoint Similarity Loss** (CVPRW 2022). The estimator is lightweight and fast, making it well-suited for crowded scenes.

<img src="https://user-images.githubusercontent.com/26127467/226655503-3cee746e-6e42-40be-82ae-6e7cae2a4c7e.jpg" alt><br>

## Usage

### Prerequisites

- Python 3.7 or higher
- PyTorch 1.6 or higher
- [MMEngine](https://github.com/open-mmlab/mmengine) v0.6.0 or higher
- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4 or higher
- [MMDetection](https://github.com/open-mmlab/mmdetection) v3.0.0rc6 or higher
- [MMYOLO](https://github.com/open-mmlab/mmyolo) v0.5.0 or higher
- [MMPose](https://github.com/open-mmlab/mmpose) v1.0.0rc1 or higher

All the commands below require `PYTHONPATH` to include the project directory so that Python can locate its module files. From the `yolox-pose/` root directory, run the following command to add the current directory to `PYTHONPATH`:

```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```

### Inference

Users can apply YOLOX-Pose models to estimate human poses with the inferencer provided in the MMPose core package. Use the command below:

```shell
python demo/inferencer_demo.py $INPUTS \
    --pose2d $CONFIG --pose2d-weights $CHECKPOINT --scope mmyolo \
    [--show] [--vis-out-dir $VIS_OUT_DIR] [--pred-out-dir $PRED_OUT_DIR]
```

For more information on using the inferencer, please see [this document](https://mmpose.readthedocs.io/en/1.x/user_guides/inference.html#out-of-the-box-inferencer).

Here is an example:

```shell
python demo/inferencer_demo.py ../../tests/data/coco/000000000785.jpg \
    --pose2d configs/yolox-pose_s_8xb32-300e_coco.py \
    --pose2d-weights https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_s_8xb32-300e_coco-9f5e3924_20230321.pth \
    --scope mmyolo --vis-out-dir vis_results
```

This will create the output image `vis_results/000000000785.jpg`, which looks like this:

<img src="https://user-images.githubusercontent.com/26127467/226552585-19b91294-9751-4599-98e7-5dae071a1761.jpg" height="360px" alt><br>
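The same inference can also be run from Python. Below is a hedged sketch using the `MMPoseInferencer` API from the document linked above; it assumes the inferencer accepts the same `pose2d`, `pose2d_weights`, and `scope` keywords that the demo script forwards to it.

```python
from mmpose.apis import MMPoseInferencer

# Sketch only: mirrors the CLI example above via the Python API.
inferencer = MMPoseInferencer(
    pose2d='configs/yolox-pose_s_8xb32-300e_coco.py',
    pose2d_weights='https://download.openmmlab.com/mmpose/v1/projects/'
    'yolox-pose/yolox-pose_s_8xb32-300e_coco-9f5e3924_20230321.pth',
    scope='mmyolo')

# The inferencer yields results lazily; consume one item for a single image.
result = next(inferencer('../../tests/data/coco/000000000785.jpg',
                         vis_out_dir='vis_results'))
```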
### Training & Testing

#### Data Preparation

Prepare the COCO dataset following the [instructions](https://mmpose.readthedocs.io/en/1.x/dataset_zoo/2d_body_keypoint.html#coco).

#### Commands

**To train with multiple GPUs:**

```shell
bash tools/dist_train.sh $CONFIG 8 --amp
```

**To train with Slurm:**

```shell
bash tools/slurm_train.sh $PARTITION $JOBNAME $CONFIG $WORKDIR --amp
```
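**To train with a single GPU** (a command not listed in this commit; it assumes the standard OpenMMLab `tools/train.py` entry point is available in this project):

```shell
python tools/train.py $CONFIG --amp
```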
**To test with a single GPU:**

```shell
python tools/test.py $CONFIG $CHECKPOINT
```

**To test with multiple GPUs:**

```shell
bash tools/dist_test.sh $CONFIG $CHECKPOINT 8
```

**To test with multiple GPUs via Slurm:**

```shell
bash tools/slurm_test.sh $PARTITION $JOBNAME $CONFIG $CHECKPOINT
```
### Results

Results on COCO val2017:

| Model | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | Download |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| [YOLOX-tiny-Pose](./configs/yolox-pose_tiny_4xb64-300e_coco.py) | 640 | 0.477 | 0.756 | 0.506 | 0.547 | 0.802 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_tiny_4xb64-300e_coco-c47dd83b_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_tiny_4xb64-300e_coco_20230321.json) |
| [YOLOX-s-Pose](./configs/yolox-pose_s_8xb32-300e_coco.py) | 640 | 0.595 | 0.836 | 0.653 | 0.658 | 0.878 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_s_8xb32-300e_coco-9f5e3924_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_s_8xb32-300e_coco_20230321.json) |
| [YOLOX-m-Pose](./configs/yolox-pose_m_4xb64-300e_coco.py) | 640 | 0.659 | 0.870 | 0.729 | 0.713 | 0.903 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_m_4xb64-300e_coco-cbd11d30_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_m_4xb64-300e_coco_20230321.json) |
| [YOLOX-l-Pose](./configs/yolox-pose_l_4xb64-300e_coco.py) | 640 | 0.679 | 0.882 | 0.749 | 0.733 | 0.911 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_l_4xb64-300e_coco-122e4cf8_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_l_4xb64-300e_coco_20230321.json) |

We have only trained models with an input size of 640, as we could not replicate the performance gain reported in the paper when increasing the input size from 640 to 960. We warmly welcome contributions if you can successfully reproduce the results from the paper!
## Citation

If this project benefits your work, please consider citing the original paper:

```bibtex
@inproceedings{maji2022yolo,
  title={YOLO-Pose: Enhancing YOLO for Multi Person Pose Estimation Using Object Keypoint Similarity Loss},
  author={Maji, Debapriya and Nagori, Soyeb and Mathew, Manu and Poddar, Deepak},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={2637--2646},
  year={2022}
}
```

Additionally, please cite our work as well:

```bibtex
@misc{mmpose2020,
  title={OpenMMLab Pose Estimation Toolbox and Benchmark},
  author={MMPose Contributors},
  howpublished = {\url{https://github.com/open-mmlab/mmpose}},
  year={2020}
}
```
Lines changed: 1 addition & 0 deletions (new file)

```
../../../../configs/_base_/datasets
```
Lines changed: 41 additions & 0 deletions (new file)

```python
default_scope = 'mmyolo'
custom_imports = dict(imports=['models', 'datasets'])

# hooks
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='mmpose.PoseVisualizationHook', enable=False),
)

# multi-processing backend
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

# visualizer
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='mmpose.PoseLocalVisualizer',
    vis_backends=vis_backends,
    name='visualizer')

# logger
log_processor = dict(
    type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
log_level = 'INFO'
load_from = None
resume = False

# file I/O backend
file_client_args = dict(backend='disk')

# training/validation/testing progress
train_cfg = dict()
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
```
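Two details worth noting in this runtime config. First, the `mmpose.`-prefixed types (e.g. `mmpose.PoseLocalVisualizer`) are cross-scope references: with `default_scope = 'mmyolo'`, the explicit prefix tells MMEngine's registry to resolve those classes from the mmpose scope instead. Second, the `custom_imports` line is what makes the project-local `models` and `datasets` packages importable at config-load time, which is why the README asks for the project root on `PYTHONPATH`. A simplified sketch of what MMEngine does with this field (the real logic lives in MMEngine's config machinery):

```python
import importlib

# Simplified sketch: MMEngine imports each listed module when the config is
# loaded, which triggers the register_module decorators inside the project's
# packages. This requires the project root to be on PYTHONPATH.
custom_imports = dict(imports=['models', 'datasets'])
for module_name in custom_imports['imports']:
    importlib.import_module(module_name)
```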
projects/yolox-pose/configs/yolox-pose_l_4xb64-300e_coco.py

Lines changed: 18 additions & 0 deletions (new file)

```python
_base_ = ['./yolox-pose_s_8xb32-300e_coco.py']

# model settings
model = dict(
    init_cfg=dict(checkpoint='https://download.openmmlab.com/mmyolo/v0/yolox/'
                  'yolox_l_fast_8xb8-300e_coco/yolox_l_fast_8xb8-300e_'
                  'coco_20230213_160715-c731eb1c.pth'),
    backbone=dict(
        deepen_factor=1.0,
        widen_factor=1.0,
    ),
    neck=dict(
        deepen_factor=1.0,
        widen_factor=1.0,
    ),
    bbox_head=dict(head_module=dict(widen_factor=1.0)))

train_dataloader = dict(batch_size=64)
```
projects/yolox-pose/configs/yolox-pose_m_4xb64-300e_coco.py

Lines changed: 18 additions & 0 deletions (new file)

```python
_base_ = ['./yolox-pose_s_8xb32-300e_coco.py']

# model settings
model = dict(
    init_cfg=dict(checkpoint='https://download.openmmlab.com/mmyolo/v0/yolox/'
                  'yolox_m_fast_8xb32-300e-rtmdet-hyp_coco/yolox_m_fast_8xb32'
                  '-300e-rtmdet-hyp_coco_20230210_144328-e657e182.pth'),
    backbone=dict(
        deepen_factor=0.67,
        widen_factor=0.75,
    ),
    neck=dict(
        deepen_factor=0.67,
        widen_factor=0.75,
    ),
    bbox_head=dict(head_module=dict(widen_factor=0.75)))

train_dataloader = dict(batch_size=64)
```
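The L and M variants above only override the depth/width multipliers and the init checkpoint on top of the S base config. The tiny variant referenced in the results table is not shown in this excerpt; below is a hypothetical sketch of what it might look like, assuming it follows the same inheritance pattern with YOLOX-tiny's usual scaling factors.

```python
# Hypothetical sketch of the tiny config (not shown in this excerpt); assumes
# YOLOX-tiny's usual scaling factors (deepen 0.33 / widen 0.375) and the same
# _base_ inheritance pattern as the l/m configs above.
_base_ = ['./yolox-pose_s_8xb32-300e_coco.py']

model = dict(
    backbone=dict(
        deepen_factor=0.33,
        widen_factor=0.375,
    ),
    neck=dict(
        deepen_factor=0.33,
        widen_factor=0.375,
    ),
    bbox_head=dict(head_module=dict(widen_factor=0.375)))

train_dataloader = dict(batch_size=64)
```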
