
Commit cdfe30b

akoumpa and chtruong814 authored
[automodel] add support for layer-freezing (#14000)
* add support for layer-freezing (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* add support for layer-freezing (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* add support for layer-freezing (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* fix (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* fix (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* fix (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* fix (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* fix (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* fix (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* add steps (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* fix (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* add docstring (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* foix (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: akoumpa <[email protected]>)
* Update layer_freezer.py (Signed-off-by: Alexandros Koumparoulis <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: akoumpa <[email protected]>)

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: akoumpa <[email protected]>
Co-authored-by: Charlie Truong <[email protected]>
1 parent d47145a commit cdfe30b

File tree

2 files changed: +169 -1 lines changed


nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py

Lines changed: 8 additions & 1 deletion
@@ -28,6 +28,7 @@
 from nemo.collections.llm.peft.lora import LoRA
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.optim.adam import pytorch_adam_with_cosine_annealing
+from nemo.lightning.pytorch.callbacks.layer_freezer import LayerFreezer
 from nemo.utils.exp_manager import TimingCallback
 
 NAME = "hf_auto_model_for_causal_lm"
@@ -185,6 +186,7 @@ def finetune_recipe(
     trust_remote_code: bool = False,
     attn_implementation: str = 'sdpa',
     use_linear_ce_loss: bool = True,
+    freeze_modules: Optional[dict] = None,
 ) -> run.Partial:
     """
     Create a fine-tuning recipe for a HFAutoModelForCausalLM model.
@@ -215,6 +217,11 @@
     Note:
         This recipe uses the SQuAD dataset for fine-tuning.
     """
+    callbacks = [run.Config(TimingCallback)]
+    if freeze_modules is not None:
+        assert isinstance(freeze_modules, list), "Expected freeze_modules to be a list"
+        assert len(freeze_modules) > 0, "Expected freeze_modules to be non-empty"
+        callbacks.append(run.Config(LayerFreezer, freeze_modules))
     recipe = run.Partial(
         finetune,
         model=model(
@@ -228,7 +235,7 @@
             num_nodes=num_nodes,
             num_gpus_per_node=num_gpus_per_node,
             max_steps=max_steps,
-            callbacks=[run.Config(TimingCallback)],
+            callbacks=callbacks,
         ),
         data=run.Config(
             SquadHFDataModule,
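
To make the new argument concrete, here is a minimal sketch of a recipe call with layer freezing enabled. The import path follows the file shown above; all other arguments are left at their defaults, and the module paths passed to freeze_modules are hypothetical (they depend on the wrapped Hugging Face model's structure):

from nemo.collections.llm.recipes.hf_auto_model_for_causal_lm import finetune_recipe

# Hypothetical module paths; they must be full attribute paths on the LightningModule.
recipe = finetune_recipe(
    freeze_modules=[
        "model.model.embed_tokens",
        "model.model.layers.0",
        "model.model.layers.1",
    ],
)
# The recipe builds a LayerFreezer callback from this list and appends it to the
# trainer callbacks next to TimingCallback.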
nemo/lightning/pytorch/callbacks/layer_freezer.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, List, Tuple, Union

from lightning.pytorch.callbacks.callback import Callback

from nemo.lightning.io.mixin import IOMixin

ScheduleValue = Union[int, List[int], Tuple[int, int]]  # -1, N, [start, end], [start, -1]


def _resolve_attr(root, path: str):
    """
    Traverse a dotted attribute path ("encoder.layer1") from root.
    """
    m = root
    for part in path.split('.'):
        m = getattr(m, part)
    return m


def make_start_end(name: str, spec: Union[int, list[int]]):
    """Translates a spec into (start, end) steps, for example:
    spec = -1            -> (0, inf)
    spec = N (int > 0)   -> (0, N - 1)
    spec = [start, end]  -> (start, end)

    Args:
        name (str): layer name.
        spec (Union[int, list[int]]): freeze spec.

    Raises:
        ValueError: if spec is not an int/list/tuple.

    Returns:
        tuple(int, int): the (start, end) steps.
    """
    # Normalize to (start, end) where end == inf means "forever"
    if isinstance(spec, int):
        if spec == -1:  # forever
            start, end = 0, math.inf
        else:  # first N steps
            start, end = 0, spec - 1
    elif isinstance(spec, (list, tuple)) and len(spec) == 2:
        start, end = spec
        start = max(start, 0)
        if end < 0:
            end = math.inf
    else:
        raise ValueError(f"Invalid schedule for '{name}': {spec}")
    return start, end
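
For reference, a few illustrative calls (not part of the commit) showing how the normalization above maps each supported spec form to an inclusive (start, end) window:

make_start_end("layer", -1)       # -> (0, inf)   frozen for the entire run
make_start_end("layer", 10)       # -> (0, 9)     frozen for the first 10 steps
make_start_end("layer", [5, 20])  # -> (5, 20)    frozen while 5 <= step <= 20
make_start_end("layer", [5, -1])  # -> (5, inf)   frozen from step 5 onward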

class LayerFreezer(Callback, IOMixin):
    """
    Freezes sub-modules of a LightningModule according to the provided schedule.
    Layer names must be given as full attribute paths (FQNs).

    Instantiate
    -----------
    # to keep layers frozen for all of training
    callback = LayerFreezer(['layer1', 'layer2'])
    # to freeze them only for some steps
    callback = LayerFreezer({'layer1': 10, 'layer2': (10, 100)})

    trainer = pl.Trainer(callbacks=[callback], ...)
    """

    def __init__(self, schedule: Union[List[str], Dict[str, ScheduleValue]]):
        """
        Args
        ----
        schedule: Union[list, dict]
            - dict:
                key = attribute path of a sub-module inside the LightningModule
                value = one of
                    : -1           -> frozen for the entire run
                    : N (int > 0)  -> frozen for the first N steps (0..N-1)
                    : [start, end] -> frozen if start <= step <= end
                use -1 for "until the end of training"
            - list:
                each item is the attribute path of a sub-module inside the
                LightningModule; it is frozen for the entire run (equivalent to
                a dict value of -1; use a dict to set a window explicitly).
        """
        super().__init__()
        assert isinstance(schedule, (list, dict)), type(schedule)
        if isinstance(schedule, list):
            schedule = {item: -1 for item in schedule}

        self.schedule: Dict[str, Tuple[int, float]] = {}
        self.frozen_state: Dict[str, bool] = {}  # last applied state

        for name, spec in schedule.items():
            self.schedule[name] = make_start_end(name, spec)
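
As a quick illustration of the constructor's normalization (the layer names here are hypothetical): a list input is first expanded to a dict of -1 specs, and every spec is then turned into a step window:

freezer = LayerFreezer(["encoder", "decoder.layers.0"])
# list -> {'encoder': -1, 'decoder.layers.0': -1} -> (0, inf) windows
print(freezer.schedule)  # {'encoder': (0, inf), 'decoder.layers.0': (0, inf)}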

    # --------------------------------------------------------------------- #
    # internal helpers
    # --------------------------------------------------------------------- #
    @staticmethod
    def _resolve_attr(root, path: str):
        """
        Traverse a dotted attribute path ("encoder.layer1") from root.
        """
        m = root
        for part in path.split('.'):
            m = getattr(m, part)
        return m

    def _apply_freeze(self, module, freeze: bool):
        """
        Enable/disable gradients + switch (eval/train) mode.
        """
        for p in module.parameters():
            p.requires_grad = not freeze
        # Optional: also flip training mode so dropout / BN are disabled.
        module.eval() if freeze else module.train()
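
A small sketch of what _apply_freeze does to a plain torch module; the Linear layer is only a stand-in for a real sub-module:

import torch.nn as nn

block = nn.Linear(4, 4)
freezer = LayerFreezer(["block"])
freezer._apply_freeze(block, freeze=True)
assert all(not p.requires_grad for p in block.parameters())
assert not block.training  # eval mode: dropout / BN statistics are not updated

freezer._apply_freeze(block, freeze=False)
assert all(p.requires_grad for p in block.parameters())
assert block.training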

    # --------------------------------------------------------------------- #
    # Lightning hooks
    # --------------------------------------------------------------------- #
    def on_train_batch_start(self, trainer, pl_module, *_):
        """
        Freezes/unfreezes the scheduled layers for the current step.

        Args:
            trainer (Trainer): the trainer
            pl_module (LightningModule): model
        """
        step = trainer.global_step

        for name, (start, end) in self.schedule.items():
            should_be_frozen = start <= step <= end
            # skip if status unchanged since last check
            if self.frozen_state.get(name, None) == should_be_frozen:
                continue

            submod = self._resolve_attr(pl_module, name)
            self._apply_freeze(submod, should_be_frozen)
            self.frozen_state[name] = should_be_frozen

    def on_train_start(self, trainer, pl_module):
        """
        In case we resume from a checkpoint, re-establish the correct state.

        Args:
            trainer (Trainer): the trainer
            pl_module (LightningModule): model
        """
        self.on_train_batch_start(trainer, pl_module, None, 0)
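
Putting it together, a hedged end-to-end sketch of attaching the callback to a Lightning trainer; the module paths and step counts are illustrative and must match attribute paths on your LightningModule:

import lightning.pytorch as pl

from nemo.lightning.pytorch.callbacks.layer_freezer import LayerFreezer

# 'model.embed_tokens' stays frozen for the whole run, 'model.layers.0' for the
# first 100 steps (0-99), and 'model.layers.1' from step 50 until the end.
freezer = LayerFreezer(
    {
        "model.embed_tokens": -1,
        "model.layers.0": 100,
        "model.layers.1": [50, -1],
    }
)
trainer = pl.Trainer(max_steps=1000, callbacks=[freezer])
# trainer.fit(module, datamodule) then toggles requires_grad / eval mode for each
# scheduled sub-module at every training batch, per on_train_batch_start above.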

0 commit comments
