Commit 3dffb1b

mannatsingh authored and facebook-github-bot committed
Make blocks attachable automatically when needed (#461)
Summary: Pull Request resolved: #461

I was frustrated by the fact that I needed to change the code for every model if I needed to be able to attach heads to it - something we need to do with most trained models. Ideally, users should be able to write their models without writing anything Classy Vision specific and still be able to attach heads. I've made a couple of diffs which get us there and wanted to see what everyone thought about them. Please look at the second diff to see the end result.

In this diff, I make the following changes -
- `build_attachable_block` becomes a private function (`_build_attachable_block`), and models don't need to call it anymore.
- Removed `_should_cache_output` and `set_cache_output` from `ClassyBlock`.
- Added a redundant `_attachable_block_names` attribute - this is needed for reading the block names inside TorchScript (`_attachable_blocks` is inaccessible) - T64918869
- Instead, when someone tries to attach a head to a module called `my_block`, we recursively search for it and wrap it in a `ClassyBlock` on the fly.
- Updated `get_classy_state` and `set_classy_state` to be compatible with the changes.
- Users can attach heads to any block which has a unique name (like `block3-2`).
- `models_classy_model_test` wasn't being run internally; renamed `classy_block_test` to `models_classy_block_test`.
- Added additional test cases to test the changes.

NOTE: This breaks all checkpoints since the model definitions have changed. We still handle old checkpoints for `ResNeXt` models to allow for a smoother transition - T61141249

Differential Revision: D20714865

fbshipit-source-id: f48e768aede1c10d25754f0fad0f24ccac9a1503
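
To make the end state concrete, here is a minimal usage sketch of the workflow this enables. The model body is plain PyTorch with no `build_attachable_block` calls; the model class and head parameters below are illustrative assumptions rather than code from this commit, and the `set_heads` argument follows the `Dict[str, Dict[str, ClassyHead]]` structure from `classy_model.py`.

import torch.nn as nn
from classy_vision.heads import FullyConnectedHead
from classy_vision.models import ClassyModel

class MyModel(ClassyModel):
    # nothing Classy Vision specific in the model definition
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        return self.pool(self.features(x))

model = MyModel()
head = FullyConnectedHead(unique_id="default_head", num_classes=10, in_plane=8)
# "pool" is any uniquely named module; it is found by a recursive search and
# wrapped in a ClassyBlock on the fly when the head is attached
model.set_heads({"pool": {head.unique_id: head}})
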
1 parent 2ed4394 commit 3dffb1b

8 files changed

Lines changed: 213 additions & 103 deletions

classy_vision/models/classy_block.py

Lines changed: 3 additions & 8 deletions
@@ -19,16 +19,11 @@ def __init__(self, name, module):
         self.name = name
         self.output = torch.zeros(0)
         self._module = module
-        self._should_cache_output = False
 
-    def set_cache_output(self, should_cache_output: bool = True):
-        """
-        Whether to cache the output of wrapped module for head execution.
-        """
-        self._should_cache_output = should_cache_output
+    def wrapped_module(self):
+        return self._module
 
     def forward(self, input):
         output = self._module(input)
-        if self._should_cache_output:
-            self.output = output
+        self.output = output
        return output
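
Assembled from the hunk above, the wrapper after this change reduces to roughly the following (the nn.Module base and the super().__init__() call sit outside the shown context and are assumed):

import torch
import torch.nn as nn

class ClassyBlock(nn.Module):
    def __init__(self, name, module):
        super().__init__()
        self.name = name
        self.output = torch.zeros(0)
        self._module = module

    def wrapped_module(self):
        return self._module

    def forward(self, input):
        # caching is now unconditional: a module is only wrapped in a
        # ClassyBlock once a head actually needs its output
        output = self._module(input)
        self.output = output
        return output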

classy_vision/models/classy_model.py

Lines changed: 51 additions & 17 deletions
@@ -6,7 +6,7 @@
 
 import copy
 from enum import Enum
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 import torch
 import torch.nn as nn
@@ -39,11 +39,13 @@ class ClassyModel(nn.Module):
 
     """
 
+    _attachable_block_names: List[str]
+
     def __init__(self):
         """Constructor for ClassyModel."""
         super().__init__()
-
         self._attachable_blocks = {}
+        self._attachable_block_names = []
         self._heads = nn.ModuleDict()
         self._head_outputs = {}
 
@@ -72,20 +74,23 @@ def get_classy_state(self, deep_copy=False):
 
         The returned state is used for checkpointing.
 
+        NOTE: For advanced users, the structure of the returned dict is -
+        `{"model": {"trunk": trunk_state, "heads": heads_state}}`.
+        The trunk state is the state of the model when no heads are attached.
+
         Args:
             deep_copy: If True, creates a deep copy of the state Dict. Otherwise, the
                 returned Dict's state will be tied to the object's.
 
         Returns:
             A state dictionary containing the state of the model.
         """
-        # If the model doesn't have head for fine-tuning, all of model's state
-        # live in the trunk
         attached_heads = self.get_heads()
-        # clear heads to get trunk only states. There shouldn't be any component
-        # states depend on heads
+        # clear heads to get the state of the model without any heads, which we refer to
+        # as the trunk state. If the model doesn't have heads attached, all of the
+        # model's state lives in the trunk.
        self._clear_heads()
-        trunk_state_dict = super().state_dict()
+        trunk_state_dict = self.state_dict()
         self.set_heads(attached_heads)
 
         head_state_dict = {}
@@ -124,11 +129,19 @@ def set_classy_state(self, state):
 
         This is used to load the state of the model from a checkpoint.
         """
+        # load the state for heads
         self.load_head_states(state)
 
-        current_state = self.state_dict()
-        current_state.update(state["model"]["trunk"])
-        super().load_state_dict(current_state)
+        # clear the heads to set the trunk's state. This is done because when heads are
+        # attached to modules, we wrap them by ClassyBlocks, thereby changing the
+        # structure of the model and its state dict. So, the trunk state is always
+        # fetched / set when there are no blocks attached.
+        attached_heads = self.get_heads()
+        self._clear_heads()
+        self.load_state_dict(state["model"]["trunk"])
+
+        # set the heads back again
+        self.set_heads(attached_heads)
 
     def forward(self, x):
         """
@@ -145,27 +158,51 @@ def extract_features(self, x):
         """
         return self.forward(x)
 
-    def build_attachable_block(self, name, module):
+    def _build_attachable_block(self, name, module):
         """
         Add a wrapper to the module to allow to attach heads to the module.
         """
         if name in self._attachable_blocks:
             raise ValueError("Found duplicated block name {}".format(name))
         block = ClassyBlock(name, module)
         self._attachable_blocks[name] = block
+        self._attachable_block_names.append(name)
         return block
 
     @property
     def attachable_block_names(self):
         """
         Return names of all attachable blocks.
         """
-        return self._attachable_blocks.keys()
+        return self._attachable_block_names
 
     def _clear_heads(self):
         # clear all existing heads
         self._heads.clear()
         self._head_outputs.clear()
+        self._strip_classy_blocks(self)
+        self._attachable_blocks = {}
+        self._attachable_block_names = []
+
+    def _strip_classy_blocks(self, module):
+        for name, child_module in module.named_children():
+            if isinstance(child_module, ClassyBlock):
+                module.add_module(name, child_module.wrapped_module())
+            self._strip_classy_blocks(child_module)
+
+    def _make_module_attachable(self, module, module_name):
+        found = False
+        for name, child_module in module.named_children():
+            if name == module_name:
+                module.add_module(
                    name, self._build_attachable_block(name, child_module)
+                )
+                found = True
+                # do not exit - we will check all possible modules and raise an
+                # exception if there are duplicates
+            found_in_child = self._make_module_attachable(child_module, module_name)
+            found = found or found_in_child
+        return found
 
     def set_heads(self, heads: Dict[str, Dict[str, ClassyHead]]):
         """Attach all the heads to corresponding blocks.
@@ -190,11 +227,8 @@ def set_heads(self, heads: Dict[str, Dict[str, ClassyHead]]):
 
         head_ids = set()
         for block_name, block_heads in heads.items():
-            if block_name not in self._attachable_blocks:
-                raise ValueError(
-                    "block {} does not exist or can not be attached".format(block_name)
-                )
-            self._attachable_blocks[block_name].set_cache_output()
+            if not self._make_module_attachable(self, block_name):
+                raise KeyError(f"{block_name} not found in the model")
             for head in block_heads.values():
                 if head.unique_id in head_ids:
                     raise ValueError("head id {} already exists".format(head.unique_id))
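
The recursive search above is the heart of the change. Here is a self-contained sketch of the same mechanism on a plain module tree (Wrapper stands in for ClassyBlock, and all names below are illustrative, not the library's code):

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, name, module):
        super().__init__()
        self.name = name
        self._module = module
        self.output = torch.zeros(0)

    def forward(self, x):
        self.output = self._module(x)  # cache the output for head execution
        return self.output

def make_attachable(module, target):
    found = False
    for name, child in module.named_children():
        if name == target:
            # add_module replaces the existing child with the wrapped one
            module.add_module(name, Wrapper(name, child))
            found = True
        # no early exit: every subtree is scanned so duplicate names surface
        found = make_attachable(child, target) or found
    return found

model = nn.Sequential()
model.add_module("stem", nn.Conv2d(3, 8, 3, padding=1))
stage = nn.Sequential()
stage.add_module("block3-2", nn.Conv2d(8, 8, 3, padding=1))
model.add_module("stage3", stage)

assert make_attachable(model, "block3-2")  # found despite being nested
model(torch.randn(2, 3, 16, 16))
block = dict(model.stage3.named_children())["block3-2"]
print(block.output.shape)  # torch.Size([2, 8, 16, 16])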

classy_vision/models/densenet.py

Lines changed: 17 additions & 19 deletions
@@ -8,6 +8,7 @@
 
 # dependencies:
 import math
+from collections import OrderedDict
 from typing import Any, Dict
 
 import torch
@@ -166,7 +167,7 @@ def __init__(
         )
         # loop over spatial resolutions:
         num_planes = init_planes
-        blocks = []
+        blocks = nn.Sequential()
         for idx, num_layers in enumerate(num_blocks):
             # add dense block
             block = self._make_dense_block(
@@ -178,18 +179,20 @@ def __init__(
                 use_se=use_se,
                 se_reduction_ratio=se_reduction_ratio,
             )
-            blocks.append(block)
+            blocks.add_module(f"block_{idx}", block)
             num_planes = num_planes + num_layers * growth_rate
 
             # add transition layer:
             if idx != len(num_blocks) - 1:
                 trans = _Transition(num_planes, num_planes // 2)
-                blocks.append(self.build_attachable_block(f"transition-{idx}", trans))
+                blocks.add_module(f"transition-{idx}", trans)
                 num_planes = num_planes // 2
 
-        blocks.append(self._make_trunk_output_block(num_planes, final_bn_relu))
+        blocks.add_module(
+            "trunk_output", self._make_trunk_output_block(num_planes, final_bn_relu)
+        )
 
-        self.features = nn.Sequential(*blocks)
+        self.features = blocks
 
         # initialize weights of convolutional and batchnorm layers:
         for m in self.modules():
@@ -208,7 +211,7 @@ def _make_trunk_output_block(self, num_planes, final_bn_relu):
         # final batch normalization:
         layers.add_module("norm-final", nn.BatchNorm2d(num_planes))
         layers.add_module("relu-final", nn.ReLU(inplace=INPLACE))
-        return self.build_attachable_block("trunk_output", layers)
+        return layers
 
     def _make_dense_block(
         self,
@@ -225,21 +228,16 @@ def _make_dense_block(
         assert is_pos_int(expansion)
 
         # create a block of dense layers at same resolution:
-        layers = []
+        layers = OrderedDict()
         for idx in range(num_layers):
-            layers.append(
-                self.build_attachable_block(
-                    f"block{block_idx}-{idx}",
-                    _DenseLayer(
-                        in_planes + idx * growth_rate,
-                        growth_rate=growth_rate,
-                        expansion=expansion,
-                        use_se=use_se,
-                        se_reduction_ratio=se_reduction_ratio,
-                    ),
-                )
+            layers[f"block{block_idx}-{idx}"] = _DenseLayer(
+                in_planes + idx * growth_rate,
+                growth_rate=growth_rate,
+                expansion=expansion,
+                use_se=use_se,
+                se_reduction_ratio=se_reduction_ratio,
             )
-        return nn.Sequential(*layers)
+        return nn.Sequential(layers)
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "DenseNet":
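
The move from Python lists to named children is what makes the automatic attachment possible: an nn.Sequential built from an OrderedDict (or via add_module) exposes stable, human-readable names that the recursive search can match. A minimal illustration:

from collections import OrderedDict
import torch.nn as nn

layers = OrderedDict()
layers["block1-0"] = nn.Conv2d(8, 8, 3, padding=1)
layers["block1-1"] = nn.Conv2d(8, 8, 3, padding=1)
block = nn.Sequential(layers)
# named children instead of "0", "1", ... so a head can target "block1-1"
print([name for name, _ in block.named_children()])  # ['block1-0', 'block1-1']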

classy_vision/models/resnext.py

Lines changed: 59 additions & 18 deletions
@@ -10,6 +10,9 @@
 
 import copy
 import math
+import re
+import warnings
+from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch.nn as nn
@@ -20,6 +23,8 @@
 from .squeeze_and_excitation_layer import SqueezeAndExcitationLayer
 
 
+# version number for the current implementation
+VERSION = 0.2
 # global setting for in-place ReLU:
 INPLACE = True
 
@@ -327,7 +332,7 @@ def __init__(
                 use_se=use_se,
                 se_reduction_ratio=se_reduction_ratio,
             )
-            blocks.append(nn.Sequential(*new_block))
+            blocks.append(new_block)
         self.blocks = nn.Sequential(*blocks)
 
         self.out_planes = out_planes[-1]
@@ -371,26 +376,21 @@ def _make_resolution_block(
         use_se=False,
         se_reduction_ratio=16,
     ):
-
         # add the desired number of residual blocks:
-        blocks = []
+        blocks = OrderedDict()
         for idx in range(num_blocks):
-            blocks.append(
-                self.build_attachable_block(
-                    "block{}-{}".format(resolution_idx, idx),
-                    self.layer_type(
-                        in_planes if idx == 0 else out_planes,
-                        out_planes,
-                        stride=stride if idx == 0 else 1,  # only first block has stride
-                        mid_planes_and_cardinality=mid_planes_and_cardinality,
-                        reduction=reduction,
-                        final_bn_relu=final_bn_relu or (idx != (num_blocks - 1)),
-                        use_se=use_se,
-                        se_reduction_ratio=se_reduction_ratio,
-                    ),
-                )
+            block_name = "block{}-{}".format(resolution_idx, idx)
+            blocks[block_name] = self.layer_type(
+                in_planes if idx == 0 else out_planes,
+                out_planes,
+                stride=stride if idx == 0 else 1,  # only first block has stride
+                mid_planes_and_cardinality=mid_planes_and_cardinality,
+                reduction=reduction,
+                final_bn_relu=final_bn_relu or (idx != (num_blocks - 1)),
+                use_se=use_se,
+                se_reduction_ratio=se_reduction_ratio,
             )
-        return blocks
+        return nn.Sequential(blocks)
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
@@ -459,6 +459,47 @@ def output_shape(self):
     def model_depth(self):
         return sum(self.num_blocks)
 
+    def _convert_model_state(self, state):
+        """Convert model state from the old implementation to the current format.
+
+        Updates the state dict in place and returns True if the state dict was updated.
+        """
+        pattern = r"blocks\.(?P<block_id_0>[0-9])\.(?P<block_id_1>[0-9])\._module\."
+        repl = r"blocks.\g<block_id_0>.block\g<block_id_0>-\g<block_id_1>."
+        trunk_dict = state["model"]["trunk"]
+        new_trunk_dict = {}
+        replaced_keys = False
+        for key, value in trunk_dict.items():
+            new_key = re.sub(pattern, repl, key)
+            if new_key != key:
+                replaced_keys = True
+            new_trunk_dict[new_key] = value
+        state["model"]["trunk"] = new_trunk_dict
+        state["version"] = VERSION
+        return replaced_keys
+
+    def get_classy_state(self):
+        state = super().get_classy_state()
+        state["version"] = VERSION
+
+    def set_classy_state(self, state):
+        version = state.get("version")
+        if version is None:
+            # convert the weights from the previous implementation of ResNeXt to the
+            # current one
+            if not self._convert_model_state(state):
+                raise RuntimeError("ResNeXt state conversion failed")
+            message = (
+                "Provided state dict is from an old implementation of ResNeXt. "
+                "This has been deprecated and will be removed soon."
+            )
+            warnings.warn(message, DeprecationWarning, stacklevel=2)
+        elif version != VERSION:
+            raise ValueError(
+                f"Unsupported ResNeXt version: {version}. Expected: {VERSION}"
+            )
+        super().set_classy_state(state)
+
 
 class _ResNeXt(ResNeXt):
     @classmethod
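
The checkpoint conversion in _convert_model_state is easiest to see on a single key; the suffix `convs.0.weight` below is a made-up example of a parameter inside a block:

import re

pattern = r"blocks\.(?P<block_id_0>[0-9])\.(?P<block_id_1>[0-9])\._module\."
repl = r"blocks.\g<block_id_0>.block\g<block_id_0>-\g<block_id_1>."

# old checkpoints addressed blocks positionally through the ClassyBlock
# wrapper; new ones address them by name, with no _module level
old_key = "blocks.3.2._module.convs.0.weight"
print(re.sub(pattern, repl, old_key))  # blocks.3.block3-2.convs.0.weight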

classy_vision/models/resnext3d.py

Lines changed: 9 additions & 2 deletions
@@ -178,6 +178,11 @@ def set_classy_state(self, state):
         # We need to support both regular checkpoint loading and 2D conv weight
         # inflation into 3D conv weight in this function.
         self.load_head_states(state)
+
+        # clear the heads to set the trunk state
+        attached_heads = self.get_heads()
+        self._clear_heads()
+
         current_state = self.state_dict()
         for name, weight_src in state["model"]["trunk"].items():
             assert name in current_state, (
@@ -217,7 +222,10 @@ def set_classy_state(self, state):
             )
 
             current_state[name] = weight_src.clone()
-        super().load_state_dict(current_state)
+        self.load_state_dict(current_state)
+
+        # set the heads back again
+        self.set_heads(attached_heads)
 
     def forward(self, x):
         """
@@ -400,7 +408,6 @@ def __init__(
                 [num_groups],
                 skip_transformation_type,
                 residual_transformation_type,
-                block_callback=self.build_attachable_block,
                 disable_pre_activation=(s == 0),
                 final_stage=(s == (num_stages - 1)),
             )
