|
# Inherit the shared default runtime settings.
_base_ = ['../../../_base_/default_runtime.py']

# runtime
# max_epochs: total number of training epochs.
# stage2_num_epochs: length of the final stage; subtracted from max_epochs
# where the stage-2 training pipeline is configured.
# base_lr: learning rate shared by the optimizer and the LR schedule.
max_epochs = 120
stage2_num_epochs = 10
base_lr = 4e-3

train_cfg = dict(max_epochs=max_epochs, val_interval=1)
randomness = dict(seed=21)
| 10 | + |
# optimizer
# AdamW at the shared base learning rate. paramwise_cfg sets the weight-decay
# multipliers for normalization layers and biases to 0.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
| 17 | + |
# learning rate
param_scheduler = [
    # linear warm-up, iteration-based (by_epoch=False), from begin=0 to
    # end=1000, starting at 1e-5 of the base learning rate
    dict(
        type='LinearLR',
        start_factor=1.0e-5,
        by_epoch=False,
        begin=0,
        end=1000),
    dict(
        # cosine annealing over the second half of training, i.e. from epoch
        # max_epochs // 2 to max_epochs, decaying down to 5% of base_lr
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)
| 39 | + |
# codec settings
# Shared by the head (input size, split ratio, decoder) and by the
# GenerateTarget step of the training pipelines (encoder).
codec = dict(
    type='SimCCLabel',
    input_size=(256, 256),
    sigma=(5.66, 5.66),
    simcc_split_ratio=2.0,
    normalize=False,
    use_dark=False)
| 48 | + |
# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    # The backbone is resolved in the mmdet scope and initialised from the
    # 'backbone.' sub-module of the pretrained checkpoint below.
    backbone=dict(
        _scope_='mmdet',
        type='CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=0.67,
        widen_factor=0.75,
        out_indices=(4, ),
        channel_attention=True,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU'),
        init_cfg=dict(
            type='Pretrained',
            prefix='backbone.',
            checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
            'rtmposev1/cspnext-m_udp-aic-coco_210e-256x192-f2f7d6f6_20230130.pth'  # noqa
        )),
    # The head takes its input size, split ratio and decoder from the codec
    # so that encoding and decoding stay consistent.
    head=dict(
        type='RTMCCHead',
        in_channels=768,
        out_channels=106,
        input_size=codec['input_size'],
        # NOTE(review): presumably input_size divided by the backbone's
        # output stride -- confirm if input_size or out_indices change.
        in_featuremap_size=(8, 8),
        simcc_split_ratio=codec['simcc_split_ratio'],
        final_layer_kernel_size=7,
        gau_cfg=dict(
            hidden_dims=256,
            s=128,
            expansion_factor=2,
            dropout_rate=0.,
            drop_path=0.,
            act_fn='SiLU',
            use_rel_bias=False,
            pos_enc=False),
        loss=dict(
            type='KLDiscretLoss',
            use_target_weight=True,
            beta=10.,
            label_softmax=True),
        decoder=codec),
    test_cfg=dict(flip_test=True))
| 98 | + |
# base dataset settings
dataset_type = 'LapaDataset'
data_mode = 'topdown'
data_root = 'data/LaPa/'

# Images are read from the local file system by default. The commented-out
# alternative below maps data_root to an s3 location through the petrel
# backend.
backend_args = dict(backend='local')
# backend_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         f'{data_root}': 's3://openmmlab/datasets/pose/LaPa/'
#     }))
| 111 | + |
# pipelines
# Training pipeline used by train_dataloader: geometric augmentation before
# the affine crop to the codec input size, colour/blur/dropout augmentation
# after it, then target generation with the shared codec.
train_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=80),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(type='PhotometricDistortion'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.2),
            dict(type='MedianBlur', p=0.2),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=1.0),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
# Evaluation pipeline shared by the val and test dataloaders: load, crop to
# the codec input size and pack, with no augmentation steps.
val_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]
| 147 | + |
# Second-stage training pipeline, passed to the pipeline-switch hook.
# Compared with train_pipeline it uses no RandomHalfBody and no
# PhotometricDistortion, disables bbox shift, narrows the scale and rotation
# ranges, and lowers the blur / dropout probabilities.
train_pipeline_stage2 = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    # dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform',
        shift_factor=0.,
        scale_factor=[0.75, 1.25],
        rotate_factor=60),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.1),
            dict(type='MedianBlur', p=0.1),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=0.5),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
| 178 | + |
# data loaders
# Training loader: shuffled samples from the LaPa train split, processed by
# train_pipeline.
train_dataloader = dict(
    batch_size=32,
    num_workers=10,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/lapa_train.json',
        data_prefix=dict(img='train/images/'),
        pipeline=train_pipeline,
    ))
# Validation loader: unshuffled samples from the LaPa val split in test mode,
# processed by val_pipeline; the last partial batch is kept (drop_last=False).
val_dataloader = dict(
    batch_size=32,
    num_workers=10,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/lapa_val.json',
        data_prefix=dict(img='val/images/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
# Test loader: same settings as the validation loader, but reading the LaPa
# test split.
test_dataloader = dict(
    batch_size=32,
    num_workers=10,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/lapa_test.json',
        data_prefix=dict(img='test/images/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
| 223 | + |
# hooks
# Checkpoint every epoch, keep at most one regular checkpoint, and track the
# best one by NME, where a lower value is better (rule='less').
default_hooks = dict(
    checkpoint=dict(
        save_best='NME', rule='less', max_keep_ckpts=1, interval=1))
| 228 | + |
custom_hooks = [
    # Exponential-momentum EMA hook; update_buffers=True is passed so buffers
    # are included as well.
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        priority=49),
    # Hands train_pipeline_stage2 to the pipeline-switch hook, with the
    # switch epoch set stage2_num_epochs before the end of training.
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=max_epochs - stage2_num_epochs,
        switch_pipeline=train_pipeline_stage2)
]
| 241 | + |
# evaluators
# NME metric with norm_mode='keypoint_distance'; the test evaluator reuses
# the validation evaluator's settings.
val_evaluator = dict(
    type='NME',
    norm_mode='keypoint_distance',
)
test_evaluator = val_evaluator
0 commit comments