-
Notifications
You must be signed in to change notification settings - Fork 9.8k
[CodeCamp2023-488] Add new configuration files for SoftTeacher algorithm in mmdetection. #10856
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Shengshenlan
wants to merge
5
commits into
open-mmlab:dev-3.x
Choose a base branch
from
Shengshenlan:shengShenLan/488
base: dev-3.x
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
0af4d7a
soft_teacher_new_config
Shengshenlan a820ba9
soft_teacher_modify_format
Shengshenlan 5781684
soft_teacher_add_commentary
Shengshenlan e6116b5
removed_excess_part
Shengshenlan 8ec48ab
Merge branch 'open-mmlab:main' into shengShenLan/488
Shengshenlan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,209 @@ | ||
| # Copyright (c) OpenMMLab. All rights reserved. | ||
| from mmengine.config import read_base | ||
|
|
||
| with read_base(): | ||
| from mmdet.datasets.transforms import * | ||
|
|
||
| from mmcv import RandomResize | ||
| from mmcv.transforms import LoadImageFromFile | ||
| from mmengine.dataset import ConcatDataset | ||
| from mmengine.dataset.sampler import DefaultSampler | ||
|
|
||
| from mmdet.datasets import (AspectRatioBatchSampler, CocoDataset, | ||
| GroupMultiSourceSampler) | ||
| from mmdet.datasets.transforms.augment_wrappers import RandAugment | ||
| from mmdet.evaluation import CocoMetric | ||
|
|
||
| # dataset settings | ||
| dataset_type = CocoDataset | ||
| data_root = 'data/coco/' | ||
|
|
||
| # Example to use different file client | ||
| # Method 1: simply set the data root and let the file I/O module | ||
| # automatically infer from prefix (not support LMDB and Memcache yet) | ||
|
|
||
| # data_root = 's3://openmmlab/datasets/detection/coco/' | ||
|
|
||
| # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 | ||
| # backend_args = dict( | ||
| # backend='petrel', | ||
| # path_mapping=dict({ | ||
| # './data/': 's3://openmmlab/datasets/detection/', | ||
| # 'data/': 's3://openmmlab/datasets/detection/' | ||
| # })) | ||
| backend_args = None | ||
|
|
||
| color_space = [ | ||
| [dict(type=ColorTransform)], | ||
| [dict(type=AutoContrast)], | ||
| [dict(type=Equalize)], | ||
| [dict(type=Sharpness)], | ||
| [dict(type=Posterize)], | ||
| [dict(type=Solarize)], | ||
| [dict(type=Color)], | ||
| [dict(type=Contrast)], | ||
| [dict(type=Brightness)], | ||
| ] | ||
|
|
||
| geometric = [ | ||
| [dict(type=Rotate)], | ||
| [dict(type=ShearX)], | ||
| [dict(type=ShearY)], | ||
| [dict(type=TranslateX)], | ||
| [dict(type=TranslateY)], | ||
| ] | ||
|
|
||
| scale = [(1333, 400), (1333, 1200)] | ||
|
|
||
| branch_field = ['sup', 'unsup_teacher', 'unsup_student'] | ||
| # pipeline used to augment labeled data, | ||
| # which will be sent to student model for supervised training. | ||
| sup_pipeline = [ | ||
| dict(type=LoadImageFromFile, backend_args=backend_args), | ||
| dict(type=LoadAnnotations, with_bbox=True), | ||
| dict( | ||
| type=RandomResize, | ||
| scale=scale, | ||
| ratio_range=(0.1, 2.0), | ||
| resize_type=Resize, | ||
| keep_ratio=True), | ||
| dict(type=RandomFlip, prob=0.5), | ||
| dict(type=RandAugment, aug_space=color_space, aug_num=1), | ||
| dict(type=FilterAnnotations, min_gt_bbox_wh=(1e-2, 1e-2)), | ||
| dict( | ||
| type=MultiBranch, | ||
| branch_field=branch_field, | ||
| sup=dict(type=PackDetInputs)) | ||
| ] | ||
|
|
||
| # pipeline used to augment unlabeled data weakly, | ||
| # which will be sent to teacher model for predicting pseudo instances. | ||
| weak_pipeline = [ | ||
| dict( | ||
| type=RandomResize, | ||
| scale=scale, | ||
| ratio_range=(0.1, 2.0), | ||
| resize_type=Resize, | ||
| keep_ratio=True), | ||
| dict(type=RandomFlip, prob=0.5), | ||
| dict( | ||
| type=PackDetInputs, | ||
| meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', | ||
| 'scale_factor', 'flip', 'flip_direction', | ||
| 'homography_matrix')), | ||
| ] | ||
|
|
||
| # pipeline used to augment unlabeled data strongly, | ||
| # which will be sent to student model for unsupervised training. | ||
| strong_pipeline = [ | ||
| dict( | ||
| type=RandomResize, | ||
| scale=scale, | ||
| ratio_range=(0.1, 2.0), | ||
| resize_type=Resize, | ||
| keep_ratio=True), | ||
| dict(type=RandomFlip, prob=0.5), | ||
| dict( | ||
| type=RandomOrder, | ||
| transforms=[ | ||
| dict(type=RandAugment, aug_space=color_space, aug_num=1), | ||
| dict(type=RandAugment, aug_space=geometric, aug_num=1), | ||
| ]), | ||
| dict(type=RandomErasing, n_patches=(1, 5), ratio=(0, 0.2)), | ||
| dict(type=FilterAnnotations, min_gt_bbox_wh=(1e-2, 1e-2)), | ||
| dict( | ||
| type=PackDetInputs, | ||
| meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', | ||
| 'scale_factor', 'flip', 'flip_direction', | ||
| 'homography_matrix')), | ||
| ] | ||
|
|
||
| # pipeline used to augment unlabeled data into different views | ||
| unsup_pipeline = [ | ||
| dict(type=LoadImageFromFile, backend_args=backend_args), | ||
| dict(type=LoadEmptyAnnotations), | ||
| dict( | ||
| type=MultiBranch, | ||
| branch_field=branch_field, | ||
| unsup_teacher=weak_pipeline, | ||
| unsup_student=strong_pipeline, | ||
| ) | ||
| ] | ||
|
|
||
| test_pipeline = [ | ||
| dict(type=LoadImageFromFile, backend_args=backend_args), | ||
| dict(type=Resize, scale=(1333, 800), keep_ratio=True), | ||
| dict( | ||
| type=PackDetInputs, | ||
| meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', | ||
| 'scale_factor')) | ||
| ] | ||
|
|
||
| batch_size = 4 | ||
| num_workers = 5 | ||
| # There are two common semi-supervised learning settings on the coco dataset: | ||
| # (1) Divide the train2017 into labeled and unlabeled datasets | ||
| # by a fixed percentage, such as 1%, 2%, 5% and 10%. | ||
| # The format of labeled_ann_file and unlabeled_ann_file are | ||
| # instances_train2017.{fold}@{percent}.json, and | ||
| # instances_train2017.{fold}@{percent}-unlabeled.json | ||
| # `fold` is used for cross-validation, and `percent` represents | ||
| # the proportion of labeled data in the train2017. | ||
| # (2) Choose the train2017 as the labeled dataset | ||
| # and unlabeled2017 as the unlabeled dataset. | ||
| # The labeled_ann_file and unlabeled_ann_file are | ||
| # instances_train2017.json and image_info_unlabeled2017.json | ||
| # We use this configuration by default. | ||
| labeled_dataset = dict( | ||
| type=dataset_type, | ||
| data_root=data_root, | ||
| ann_file='annotations/instances_train2017.json', | ||
| data_prefix=dict(img='train2017/'), | ||
| filter_cfg=dict(filter_empty_gt=True, min_size=32), | ||
| pipeline=sup_pipeline, | ||
| backend_args=backend_args) | ||
|
|
||
| unlabeled_dataset = dict( | ||
| type=dataset_type, | ||
| data_root=data_root, | ||
| ann_file='annotations/instances_unlabeled2017.json', | ||
| data_prefix=dict(img='unlabeled2017/'), | ||
| filter_cfg=dict(filter_empty_gt=False), | ||
| pipeline=unsup_pipeline, | ||
| backend_args=backend_args) | ||
|
|
||
| train_dataloader = dict( | ||
| batch_size=batch_size, | ||
| num_workers=num_workers, | ||
| persistent_workers=True, | ||
| sampler=dict( | ||
| type=GroupMultiSourceSampler, | ||
| batch_size=batch_size, | ||
| source_ratio=[1, 4]), | ||
| dataset=dict( | ||
| type=ConcatDataset, datasets=[labeled_dataset, unlabeled_dataset])) | ||
|
|
||
| val_dataloader = dict( | ||
| batch_size=1, | ||
| num_workers=2, | ||
| persistent_workers=True, | ||
| drop_last=False, | ||
| sampler=dict(type=DefaultSampler, shuffle=False), | ||
| dataset=dict( | ||
| type=dataset_type, | ||
| data_root=data_root, | ||
| ann_file='annotations/instances_val2017.json', | ||
| data_prefix=dict(img='val2017/'), | ||
| test_mode=True, | ||
| pipeline=test_pipeline, | ||
| backend_args=backend_args)) | ||
|
|
||
| test_dataloader = val_dataloader | ||
|
|
||
| val_evaluator = dict( | ||
| type=CocoMetric, | ||
| ann_file=data_root + 'annotations/instances_val2017.json', | ||
| metric='bbox', | ||
| format_only=False, | ||
| backend_args=backend_args) | ||
| test_evaluator = val_evaluator |
89 changes: 89 additions & 0 deletions
89
mmdet/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_180k_semi_10_coco.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| # Copyright (c) OpenMMLab. All rights reserved. | ||
| from mmengine.config import read_base | ||
|
|
||
| with read_base(): | ||
| from .._base_.models.faster_rcnn_r50_fpn import * | ||
| from .._base_.datasets.semi_coco_detection import * | ||
| from .._base_.default_runtime import * | ||
|
|
||
| from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper | ||
| from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR | ||
| from mmengine.runner.loops import IterBasedTrainLoop, TestLoop | ||
| from torch.optim.sgd import SGD | ||
|
|
||
| from mmdet.engine.hooks.mean_teacher_hook import MeanTeacherHook | ||
| from mmdet.engine.runner import TeacherStudentValLoop | ||
| from mmdet.models.backbones.resnet import ResNet | ||
| from mmdet.models.data_preprocessors.data_preprocessor import ( | ||
| DetDataPreprocessor, MultiBranchDataPreprocessor) | ||
| from mmdet.models.detectors.soft_teacher import SoftTeacher | ||
|
|
||
| detector = model | ||
| detector.data_preprocessor.update( | ||
| dict( | ||
| type=DetDataPreprocessor, | ||
| mean=[103.530, 116.280, 123.675], | ||
| std=[1.0, 1.0, 1.0], | ||
| bgr_to_rgb=False, | ||
| pad_size_divisor=32)) | ||
|
|
||
| model = dict( | ||
| type=SoftTeacher, | ||
| detector=detector, | ||
| data_preprocessor=dict( | ||
| type=MultiBranchDataPreprocessor, | ||
| data_preprocessor=detector.data_preprocessor), | ||
| semi_train_cfg=dict( | ||
| freeze_teacher=True, | ||
| sup_weight=1.0, | ||
| unsup_weight=4.0, | ||
| pseudo_label_initial_score_thr=0.5, | ||
| rpn_pseudo_thr=0.9, | ||
| cls_pseudo_thr=0.9, | ||
| reg_pseudo_thr=0.02, | ||
| jitter_times=10, | ||
| jitter_scale=0.06, | ||
| min_pseudo_bbox_wh=(1e-2, 1e-2)), | ||
| semi_test_cfg=dict(predict_on='teacher')) | ||
|
|
||
| # 10% coco train2017 is set as labeled dataset | ||
| labeled_dataset = labeled_dataset | ||
| unlabeled_dataset = unlabeled_dataset | ||
| labeled_dataset.ann_file = 'semi_anns/[email protected]' | ||
| unlabeled_dataset.ann_file = 'semi_anns/' \ | ||
| '[email protected]' | ||
| unlabeled_dataset.data_prefix = dict(img='train2017/') | ||
| train_dataloader.update( | ||
| dict(dataset=dict(datasets=[labeled_dataset, unlabeled_dataset]))) | ||
|
|
||
| # training schedule for 180k | ||
| train_cfg = dict(type=IterBasedTrainLoop, max_iters=180000, val_interval=5000) | ||
| val_cfg = dict(type=TeacherStudentValLoop) | ||
| test_cfg = dict(type=TestLoop) | ||
|
|
||
| # learning rate policy | ||
| param_scheduler = [ | ||
| dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=500), | ||
| dict( | ||
| type=MultiStepLR, | ||
| begin=0, | ||
| end=180000, | ||
| by_epoch=False, | ||
| milestones=[120000, 160000], | ||
| gamma=0.1) | ||
| ] | ||
|
|
||
| # optimizer | ||
| # The learning rate in the old configuration was 0.01, | ||
| # but there was a loss cls error during runtime. | ||
| # Because the process of calculating losses is too close to zero, | ||
| # the learning rate is adjusted to 0.001. | ||
| optim_wrapper = dict( | ||
| type=OptimWrapper, | ||
| optimizer=dict(type=SGD, lr=0.001, momentum=0.9, weight_decay=0.0001)) | ||
|
|
||
| default_hooks.update( | ||
| dict(checkpoint=dict(by_epoch=False, interval=10000, max_keep_ckpts=2))) | ||
| log_processor.update(dict(by_epoch=False)) | ||
|
|
||
| custom_hooks = [dict(type=MeanTeacherHook)] | ||
13 changes: 13 additions & 0 deletions
13
mmdet/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_180k_semi_1_coco.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright (c) OpenMMLab. All rights reserved. | ||
| from mmengine.config import read_base | ||
|
|
||
| with read_base(): | ||
| from .soft_teacher_faster_rcnn_r50_caffe_fpn_180k_semi_10_coco import * | ||
|
|
||
| # 1% coco train2017 is set as labeled dataset | ||
| labeled_dataset = labeled_dataset | ||
| unlabeled_dataset = unlabeled_dataset | ||
| labeled_dataset.ann_file = 'semi_anns/[email protected]' | ||
| unlabeled_dataset.ann_file = 'semi_anns/[email protected]' | ||
| train_dataloader.update( | ||
| dict(dataset=dict(datasets=[labeled_dataset, unlabeled_dataset]))) |
13 changes: 13 additions & 0 deletions
13
mmdet/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_180k_semi_2_coco.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright (c) OpenMMLab. All rights reserved. | ||
| from mmengine.config import read_base | ||
|
|
||
| with read_base(): | ||
| from .soft_teacher_faster_rcnn_r50_caffe_fpn_180k_semi_10_coco import * | ||
|
|
||
| # 1% coco train2017 is set as labeled dataset | ||
| labeled_dataset = labeled_dataset | ||
| unlabeled_dataset = unlabeled_dataset | ||
| labeled_dataset.ann_file = 'semi_anns/[email protected]' | ||
| unlabeled_dataset.ann_file = 'semi_anns/[email protected]' | ||
| train_dataloader.update( | ||
| dict(dataset=dict(datasets=[labeled_dataset, unlabeled_dataset]))) |
13 changes: 13 additions & 0 deletions
13
mmdet/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_180k_semi_5_coco.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright (c) OpenMMLab. All rights reserved. | ||
| from mmengine.config import read_base | ||
|
|
||
| with read_base(): | ||
| from .soft_teacher_faster_rcnn_r50_caffe_fpn_180k_semi_10_coco import * | ||
|
|
||
| # 1% coco train2017 is set as labeled dataset | ||
| labeled_dataset = labeled_dataset | ||
| unlabeled_dataset = unlabeled_dataset | ||
| labeled_dataset.ann_file = 'semi_anns/[email protected]' | ||
| unlabeled_dataset.ann_file = 'semi_anns/[email protected]' | ||
| train_dataloader.update( | ||
| dict(dataset=dict(datasets=[labeled_dataset, unlabeled_dataset]))) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.