Commit 4812617

add steps

Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 8de4228 commit 4812617

File tree

1 file changed: +73 -10 lines changed

nemo/lightning/pytorch/callbacks/layer_freezer.py

Lines changed: 73 additions & 10 deletions
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List
+from typing import Dict, List, Tuple, Union
+import math
 
 from lightning.pytorch.callbacks.callback import Callback
 
@@ -27,25 +28,78 @@ def _resolve_attr(root, path: str):
         m = getattr(m, part)
     return m
 
+def make_start_end(spec: Union[int, list[int]]):
+    start, end = 0, 0
+    # Normalize to (start, end); end == math.inf means "frozen forever"
+    if isinstance(spec, int):
+        if spec == -1:  # forever
+            start, end = 0, math.inf
+        else:  # first N steps
+            start, end = 0, spec - 1
+    elif isinstance(spec, (list, tuple)) and len(spec) == 2:
+        start, end = spec
+        start = 0 if start == -1 else start
+        end = math.inf if end == -1 else end
+    else:
+        raise ValueError(f"Invalid schedule spec: {spec}")
+    return start, end
+
 class LayerFreezer(Callback, IOMixin):
     """
     Freezes sub-modules of a LightningModule based on the list provided. The list of layers should
     be the full FQN.
 
     Instantiate
     -----------
+    # to keep layers frozen for the entire run
     callback = LayerFreezer(['layer1', 'layer2',])
+    # to freeze layers only for some steps
+    callback = LayerFreezer({'layer1': 10, 'layer2': (10, 100)})
+
     trainer = pl.Trainer(callbacks=[callback], ...)
     """
-
-    def __init__(self, frozen_layers: List[str]):
+    def __init__(self, schedule: Union[List[str], Dict[str, ScheduleValue]]):
         """
         Args
         ----
-        frozen_layers: List[str] list of layers that are frozen
+        schedule: Union[list, dict]
+            - dict:
+                key = attribute path of a sub-module inside the LightningModule
+                value = one of
+                    : -1 -> frozen for the entire run
+                    : N (int > 0) -> frozen for the first N steps (0..N-1)
+                    : [start, end] -> frozen while start <= step <= end
+                      (use -1 as end to mean "until end of training")
+            - list:
+                each item is the attribute path of a sub-module inside the LightningModule;
+                the layer stays frozen for the entire run (same as a dict value of -1).
         """
         super().__init__()
-        self.frozen_layers = frozen_layers
+        assert isinstance(schedule, (list, dict)), type(schedule)
+        if isinstance(schedule, list):
+            schedule = {
+                item: -1
+                for item in schedule
+            }
+
+        self.schedule: Dict[str, Tuple[int, float]] = {}
+        self.frozen_state: Dict[str, bool] = {}  # last applied state
+
+        for name, spec in schedule.items():
+            self.schedule[name] = make_start_end(spec)
+
+    # --------------------------------------------------------------------- #
+    # internal helpers
+    # --------------------------------------------------------------------- #
+    @staticmethod
+    def _resolve_attr(root, path: str):
+        """
+        Traverse dotted attribute path ("encoder.layer1") from root.
+        """
+        m = root
+        for part in path.split('.'):
+            m = getattr(m, part)
+        return m
 
     def _apply_freeze(self, module, freeze: bool):
         """
@@ -56,23 +110,32 @@ def _apply_freeze(self, module, freeze: bool):
         # Optional: also flip training mode so dropout / BN are disabled.
         module.eval() if freeze else module.train()
 
+    # --------------------------------------------------------------------- #
+    # Lightning hooks
+    # --------------------------------------------------------------------- #
     def on_train_batch_start(self, trainer, pl_module, *_):
         """
         freezes layers listed on frozen_layers
-
         Args:
             trainer (Trainer): the trainer
             pl_module (LightningModule): model
         """
-        for name in self.frozen_layers:
-            submod = _resolve_attr(pl_module, name)
-            self._apply_freeze(submod, True)
+        step = trainer.global_step
+
+        for name, (start, end) in self.schedule.items():
+            should_be_frozen = (start <= step <= end)
+            # skip if status unchanged since last check
+            if self.frozen_state.get(name, None) == should_be_frozen:
+                continue
+
+            submod = self._resolve_attr(pl_module, name)
+            self._apply_freeze(submod, should_be_frozen)
+            self.frozen_state[name] = should_be_frozen
 
     def on_train_start(self, trainer, pl_module):
        """
         on_train_start
         In case we resume from checkpoint, re-establish correct state
-
         Args:
             trainer (Trainer): the trainer
             pl_module (LightningModule): model

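With the schedule in place, on_train_batch_start re-evaluates each layer's window against trainer.global_step on every batch and, thanks to the frozen_state cache, only touches a sub-module when its frozen status actually changes. A minimal usage sketch (module names and step counts are illustrative, not taken from this commit):

    import lightning.pytorch as pl

    from nemo.lightning.pytorch.callbacks.layer_freezer import LayerFreezer

    freezer = LayerFreezer({
        'encoder': -1,        # frozen for the entire run
        'decoder': 1000,      # frozen for steps 0..999
        'head': [500, 2000],  # frozen while 500 <= step <= 2000
    })
    trainer = pl.Trainer(callbacks=[freezer], max_steps=5000)
    # trainer.fit(model, datamodule=datamodule)  # model/datamodule supplied by the user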