84 changes: 84 additions & 0 deletions configs/selfsup/_base_/datasets/coco_orl.py
@@ -0,0 +1,84 @@
import copy
# dataset settings
dataset_type = 'mmdet.CocoDataset'
# data_root = 'data/coco/'
data_root = '../data/coco/'
file_client_args = dict(backend='disk')


view_pipeline = [
dict(
type='RandomResizedCrop',
size=224,
interpolation='bicubic',
backend='pillow'),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomApply',
transforms=[
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.2,
hue=0.1)
],
prob=0.8),
dict(
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.2989)),
dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=1),
dict(type='RandomSolarize', prob=0)
]


view_pipeline1 = copy.deepcopy(view_pipeline)
view_pipeline2 = copy.deepcopy(view_pipeline)
view_pipeline2[4]['prob'] = 0.1 # gaussian blur
view_pipeline2[5]['prob'] = 0.2 # solarization
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='MultiView',
num_views=[1, 1],
transforms=[view_pipeline1, view_pipeline2]),
dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]


train_dataloader = dict(
batch_size=64,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
pipeline=train_pipeline))

custom_hooks = [
dict(
type='ExtractorHook',
extract_dataloader=dict(
batch_size=256,
num_workers=6,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
pipeline=train_pipeline)),
        normalize=True),
]
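Note on the dataset config above: the two view pipelines follow BYOL's asymmetric augmentation recipe, where view 1 is always Gaussian-blurred and never solarized, while view 2 is blurred with probability 0.1 and solarized with probability 0.2. Below is a minimal inspection sketch (not part of the PR), assuming MMEngine's Config.fromfile and the config path shown above:

from mmengine.config import Config

# hypothetical helper script for inspecting the two view pipelines
cfg = Config.fromfile('configs/selfsup/_base_/datasets/coco_orl.py')

# train_pipeline = [LoadImageFromFile, MultiView]; MultiView holds one
# augmentation pipeline per view.
multi_view = cfg.train_pipeline[1]
for i, pipeline in enumerate(multi_view['transforms'], start=1):
    blur = next(t for t in pipeline if t['type'] == 'RandomGaussianBlur')
    solarize = next(t for t in pipeline if t['type'] == 'RandomSolarize')
    print(f'view {i}: blur prob={blur["prob"]}, solarize prob={solarize["prob"]}')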
55 changes: 55 additions & 0 deletions configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco.py
@@ -0,0 +1,55 @@

_base_ = [
'../../_base_/models/byol.py',
'../../_base_/datasets/coco_orl.py',
'../../_base_/schedules/sgd_coslr-200e_in1k.py',
'../../_base_/default_runtime.py',
]

# model settings
model = dict(
neck=dict(
type='NonLinearNeck',
in_channels=2048,
hid_channels=4096,
out_channels=256,
num_layers=2,
with_bias=False,
with_last_bn=False,
with_avg_pool=True),
head=dict(type='LatentPredictHead',
predictor=dict(type='NonLinearNeck',
in_channels=256, hid_channels=4096,
out_channels=256, num_layers=2,
with_bias=False, with_last_bn=False, with_avg_pool=False)))


update_interval = 1  # interval for gradient accumulation
# AMP optimizer
optimizer = dict(type='SGD', lr=0.4, weight_decay=0.0001, momentum=0.9)
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=optimizer,
accumulative_counts=update_interval,
)
warmup_epochs = 4
total_epochs = 800
# learning policy
param_scheduler = [
# warmup
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
end=warmup_epochs,
        # update the learning rate after every iteration
        convert_to_iter_based=True),
    # cosine annealing after warmup
    dict(
        type='CosineAnnealingLR', eta_min=0., T_max=total_epochs,
        by_epoch=True, begin=warmup_epochs, end=total_epochs)
]


# runtime settings
default_hooks = dict(checkpoint=dict(interval=100))
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=total_epochs)
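For reference, the schedule above ramps the learning rate linearly from lr * 0.0001 up to lr over the first 4 epochs, then decays it along a cosine curve towards eta_min=0 for the remaining epochs. A rough standalone sketch of the resulting per-epoch learning rate follows; it is only an approximation, since the actual MMEngine schedulers convert the warmup to an iteration-based schedule and also honor T_max:

import math

# values taken from the config above
base_lr, start_factor = 0.4, 0.0001
warmup_epochs, total_epochs = 4, 800

def approx_lr(epoch: int) -> float:
    if epoch < warmup_epochs:
        # LinearLR: ramp from base_lr * start_factor up to base_lr
        factor = start_factor + (1 - start_factor) * epoch / warmup_epochs
        return base_lr * factor
    # CosineAnnealingLR: cosine decay towards eta_min=0 after warmup
    t = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1 + math.cos(math.pi * t))

for epoch in (0, 2, 4, 400, 799):
    print(epoch, round(approx_lr(epoch), 6))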
@@ -0,0 +1,72 @@
_base_ = [
'../../_base_/models/byol.py',
'../../_base_/datasets/coco_orl.py',
'../../_base_/schedules/sgd_coslr-200e_in1k.py',
'../../_base_/default_runtime.py',
]
# model settings
model = dict(
neck=dict(
type='NonLinearNeck',
in_channels=2048,
hid_channels=4096,
out_channels=256,
num_layers=2,
with_bias=False,
with_last_bn=False,
with_avg_pool=True),
head=dict(type='LatentPredictHead',
predictor=dict(type='NonLinearNeck',
in_channels=256, hid_channels=4096,
out_channels=256, num_layers=2,
with_bias=False, with_last_bn=False, with_avg_pool=False)))


update_interval = 1  # interval for gradient accumulation
# AMP optimizer
optimizer = dict(type='SGD', lr=0.4, weight_decay=0.0001, momentum=0.9)
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=optimizer,
accumulative_counts=update_interval,
)
warmup_epochs = 4
total_epochs = 800
# learning policy
param_scheduler = [
# warmup
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
end=warmup_epochs,
        # update the learning rate after every iteration
        convert_to_iter_based=True),
    # cosine annealing after warmup
    dict(
        type='CosineAnnealingLR', eta_min=0., T_max=total_epochs,
        by_epoch=True, begin=warmup_epochs, end=total_epochs)
]


# runtime settings
default_hooks = dict(checkpoint=dict(interval=100))
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=total_epochs)


custom_hooks = [
dict(
type='ExtractorHook',
extract_dataloader=dict(
batch_size=256,
num_workers=6,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
collate_fn=dict(type='default_collate'),
dataset=dict(
type={{_base_.dataset_type}},
data_root={{_base_.data_root}},
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
pipeline={{_base_.train_pipeline}})),
normalize=True
),
]
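The {{_base_.dataset_type}}, {{_base_.data_root}} and {{_base_.train_pipeline}} placeholders above are MMEngine's base-variable references: when the config is parsed they are replaced by the values defined in the inherited coco_orl.py base file, so the extraction dataloader stays in sync with the training dataloader. A minimal sketch of the mechanism, with hypothetical file names that are not part of this PR:

# base_cfg.py (hypothetical)
#     dataset_type = 'mmdet.CocoDataset'
#     data_root = '../data/coco/'
#
# child_cfg.py (hypothetical)
#     _base_ = ['./base_cfg.py']
#     extract_dataset = dict(
#         type={{_base_.dataset_type}},    # resolved to 'mmdet.CocoDataset'
#         data_root={{_base_.data_root}})  # resolved to '../data/coco/'

from mmengine.config import Config

cfg = Config.fromfile('child_cfg.py')
assert cfg.extract_dataset['type'] == 'mmdet.CocoDataset'
assert cfg.extract_dataset['data_root'] == '../data/coco/'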
69 changes: 69 additions & 0 deletions mmselfsup/engine/hooks/extractor_hook.py
@@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from mmengine.dist import is_distributed
from mmengine.hooks import Hook
from mmengine.logging import print_log

from mmselfsup.models.utils import Extractor
from mmselfsup.registry import HOOKS


@HOOKS.register_module()
class ExtractorHook(Hook):
"""feature extractor hook.

This hook includes the global clustering process in DC.

Args:
extractor (dict): Config dict for feature extraction.
clustering (dict): Config dict that specifies the clustering algorithm.
unif_sampling (bool): Whether to apply uniform sampling.
reweight (bool): Whether to apply loss re-weighting.
reweight_pow (float): The power of re-weighting.
init_memory (bool): Whether to initialize memory banks used in ODC.
Defaults to False.
initial (bool): Whether to call the hook initially. Defaults to True.
interval (int): Frequency of epochs to call the hook. Defaults to 1.
seed (int, optional): Random seed. Defaults to None.
"""

def __init__(
self,
extract_dataloader: dict,
        normalize: bool = True,
seed: Optional[int] = None) -> None:

self.dist_mode = is_distributed()
self.extractor = Extractor(
extract_dataloader=extract_dataloader,
seed=seed,
dist_mode=self.dist_mode,
pool_cfg=None)
        self.normalize = normalize

    def before_run(self, runner):
self._extract_func(runner)

    def _extract_func(self, runner):
        # step 1: extract features from the whole extraction dataset
        runner.model.eval()
        features = torch.from_numpy(
            self.extractor(runner.model.module)['feat'])
        if self.normalize:
            features = nn.functional.normalize(features, dim=1)
        # step 2: save features (rank 0 only in distributed mode)
        if not self.dist_mode or runner.rank == 0:
            np.save(
                '{}/feature_epoch_{}.npy'.format(runner.work_dir,
                                                 runner.epoch),
                features.numpy())
            print_log(
                'Feature extraction done! Total features: {}\t'
                'feature dimension: {}'.format(features.size(0),
                                               features.size(1)),
                logger='current')
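Since the hook runs feature extraction in before_run, runner.epoch is still 0 and the (optionally L2-normalized) features are written to <work_dir>/feature_epoch_0.npy on rank 0. A minimal sketch of loading the saved array afterwards, with a placeholder work directory:

import numpy as np

work_dir = 'work_dirs/orl_stage1'  # placeholder path
features = np.load(f'{work_dir}/feature_epoch_0.npy')

print(features.shape)  # (num_images, feat_dim)
# with normalize=True every row is L2-normalized, so the norms are ~1
print(np.linalg.norm(features, axis=1)[:5])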