From e6122ad37775abb833aced83387441c9ef10e99f Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 10:12:53 -0700
Subject: [PATCH 01/16] add support for layer-freezing

Signed-off-by: Alexandros Koumparoulis
---
 .../recipes/hf_auto_model_for_causal_lm.py |  7 +-
 .../pytorch/callbacks/layer_freezer.py     | 70 +++++++++++++++++++
 2 files changed, 76 insertions(+), 1 deletion(-)
 create mode 100644 nemo/lightning/pytorch/callbacks/layer_freezer.py

diff --git a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
index 46c0514b4d50..3d167d284135 100644
--- a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
@@ -30,6 +30,7 @@
 from nemo.collections.llm.recipes.optim.adam import pytorch_adam_with_cosine_annealing
 from nemo.utils.exp_manager import TimingCallback
 
+
 NAME = "hf_auto_model_for_causal_lm"
 
 
@@ -185,6 +186,7 @@ def finetune_recipe(
     trust_remote_code: bool = False,
     attn_implementation: str = 'sdpa',
     use_linear_ce_loss: bool = True,
+    freeze_modules: Optional[dict] = None,
 ) -> run.Partial:
     """
     Create a fine-tuning recipe for a HFAutoModelForCausalLM model.
@@ -215,6 +217,9 @@ def finetune_recipe(
 
     Note: This recipe uses the SQuAD dataset for fine-tuning.
     """
+    callback = [run.Config(TimingCallback)]
+    if freeze_modules is not None:
+        callbacks.append(run.Config(LayerFreezer, freeze_modules))
     recipe = run.Partial(
         finetune,
         model=model(
@@ -228,7 +233,7 @@ def finetune_recipe(
         num_nodes=num_nodes,
         num_gpus_per_node=num_gpus_per_node,
         max_steps=max_steps,
-        callbacks=[run.Config(TimingCallback)],
+        callbacks=callbacks,
     ),
     data=run.Config(
         SquadHFDataModule,
diff --git a/nemo/lightning/pytorch/callbacks/layer_freezer.py b/nemo/lightning/pytorch/callbacks/layer_freezer.py
new file mode 100644
index 000000000000..948f040dc860
--- /dev/null
+++ b/nemo/lightning/pytorch/callbacks/layer_freezer.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, List, Tuple, Union
+import math
+
+import pytorch_lightning as pl
+from lightning.pytorch.callbacks.callback import Callback
+
+from nemo.lightning.io.mixin import IOMixin
+
+
+class LayerFreezer(Callback, IOMixin):
+    """
+    Freezes sub-modules of a LightningModule based on the list provided. The list of layers should
+    be the full FQN.
+
+    Instantiate
+    -----------
+    callback = LayerFreezer(['layer1', 'layer2',])
+    trainer = pl.Trainer(callbacks=[callback], ...)
+    """
+
+    def __init__(self, frozen_layers: List[str]):
+        """
+        Args
+        ----
+        frozen_layers: List[str] list of layers that are frozen
+        """
+        super().__init__()
+        self.frozen_layers = frozen_layers
+
+    @staticmethod
+    def _resolve_attr(root, path: str):
+        """
+        Traverse dotted attribute path (“encoder.layer1”) from root.
+        """
+        m = root
+        for part in path.split('.'):
+            m = getattr(m, part)
+        return m
+
+    def _apply_freeze(self, module, freeze: bool):
+        """
+        Enable/disable gradients + switch (eval/train) mode.
+        """
+        for p in module.parameters():
+            p.requires_grad = not freeze
+        # Optional: also flip training mode so dropout / BN are disabled.
+        module.eval() if freeze else module.train()
+
+    def on_train_batch_start(self, trainer, pl_module, *_):
+        for name in self.frozen_layers:
+            submod = self._resolve_attr(pl_module, name)
+            self._apply_freeze(submod, should_be_frozen)
+            self.frozen_state[name] = should_be_frozen
+
+    # In case we resume from checkpoint, re-establish correct state
+    def on_train_start(self, trainer, pl_module):
+        self.on_train_batch_start(trainer, pl_module, None, 0)

From 6e6a4b5950dbb279854af0c4bb1db0a870d7cb2d Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 10:13:42 -0700
Subject: [PATCH 02/16] add support for layer-freezing

Signed-off-by: Alexandros Koumparoulis
---
 nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
index 3d167d284135..c11883a252f0 100644
--- a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
@@ -30,7 +30,6 @@
 from nemo.collections.llm.recipes.optim.adam import pytorch_adam_with_cosine_annealing
 from nemo.utils.exp_manager import TimingCallback
 
-
 NAME = "hf_auto_model_for_causal_lm"
 
 

From 276ddc6a3346238fe54a46a54c1d767e79219b5e Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 10:14:20 -0700
Subject: [PATCH 03/16] add support for layer-freezing

Signed-off-by: Alexandros Koumparoulis
---
 nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
index c11883a252f0..4357865b9760 100644
--- a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
@@ -29,6 +29,7 @@
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.optim.adam import pytorch_adam_with_cosine_annealing
 from nemo.utils.exp_manager import TimingCallback
+from nemo.lightning.pytorch.callbacks.layer_freezer import LayerFreezer
 
 NAME = "hf_auto_model_for_causal_lm"
 

From 36cbe5728d77f1a8ae709dc3b6da36c864caaffa Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 10:15:19 -0700
Subject: [PATCH 04/16] fix

Signed-off-by: Alexandros Koumparoulis
---
 .../pytorch/callbacks/layer_freezer.py | 21 ++++++++++----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/nemo/lightning/pytorch/callbacks/layer_freezer.py b/nemo/lightning/pytorch/callbacks/layer_freezer.py
index 948f040dc860..8e8bf5d137d8 100644
--- a/nemo/lightning/pytorch/callbacks/layer_freezer.py
+++ b/nemo/lightning/pytorch/callbacks/layer_freezer.py
@@ -20,6 +20,15 @@
 
 from nemo.lightning.io.mixin import IOMixin
 
+def _resolve_attr(root, path: str):
+    """
+    Traverse dotted attribute path (“encoder.layer1”) from root.
+    """
+    m = root
+    for part in path.split('.'):
+        m = getattr(m, part)
+    return m
+
 class LayerFreezer(Callback, IOMixin):
     """
     Freezes sub-modules of a LightningModule based on the list provided. The list of layers should
@@ -40,16 +49,6 @@ def __init__(self, frozen_layers: List[str]):
         super().__init__()
         self.frozen_layers = frozen_layers
 
-    @staticmethod
-    def _resolve_attr(root, path: str):
-        """
-        Traverse dotted attribute path (“encoder.layer1”) from root.
-        """
-        m = root
-        for part in path.split('.'):
-            m = getattr(m, part)
-        return m
-
     def _apply_freeze(self, module, freeze: bool):
         """
         Enable/disable gradients + switch (eval/train) mode.
@@ -61,7 +60,7 @@ def _apply_freeze(self, module, freeze: bool):
         for p in module.parameters():
             p.requires_grad = not freeze
         # Optional: also flip training mode so dropout / BN are disabled.
         module.eval() if freeze else module.train()
 
     def on_train_batch_start(self, trainer, pl_module, *_):
         for name in self.frozen_layers:
-            submod = self._resolve_attr(pl_module, name)
+            submod = _resolve_attr(pl_module, name)
             self._apply_freeze(submod, should_be_frozen)
             self.frozen_state[name] = should_be_frozen

From def53d1f6c406077091f831f20bda458d26cf2bd Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 10:15:33 -0700
Subject: [PATCH 05/16] fix

Signed-off-by: Alexandros Koumparoulis
---
 nemo/lightning/pytorch/callbacks/layer_freezer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/nemo/lightning/pytorch/callbacks/layer_freezer.py b/nemo/lightning/pytorch/callbacks/layer_freezer.py
index 8e8bf5d137d8..6c4e74412f2a 100644
--- a/nemo/lightning/pytorch/callbacks/layer_freezer.py
+++ b/nemo/lightning/pytorch/callbacks/layer_freezer.py
@@ -62,7 +62,6 @@ def on_train_batch_start(self, trainer, pl_module, *_):
         for name in self.frozen_layers:
             submod = _resolve_attr(pl_module, name)
             self._apply_freeze(submod, should_be_frozen)
-            self.frozen_state[name] = should_be_frozen
 
     # In case we resume from checkpoint, re-establish correct state
     def on_train_start(self, trainer, pl_module):

From 52abcaff38a6c98b15c2450ed828a9f5abd971cc Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 10:17:49 -0700
Subject: [PATCH 06/16] fix

Signed-off-by: Alexandros Koumparoulis
---
 nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
index 4357865b9760..3c64a14eee93 100644
--- a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
@@ -219,6 +219,8 @@ def finetune_recipe(
     """
     callback = [run.Config(TimingCallback)]
     if freeze_modules is not None:
+        assert isinstance(freeze_modules, list), "Expected freeze_modules to be a list"
+        assert len(freeze_modules) > 0, "Expected freeze_modules to be non-empty"
         callbacks.append(run.Config(LayerFreezer, freeze_modules))
     recipe = run.Partial(
         finetune,

From fd79215087d8eba1833ed2994fc8145ac5c12dc6 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 10:19:56 -0700
Subject: [PATCH 07/16] fix

Signed-off-by: Alexandros Koumparoulis
---
 nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
index 3c64a14eee93..d1e1e59f3bcd 100644
--- a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
@@ -217,7 +217,7 @@ def finetune_recipe(
     Note: This recipe uses the SQuAD dataset for fine-tuning.
     """
-    callback = [run.Config(TimingCallback)]
+    callbacks = [run.Config(TimingCallback)]
     if freeze_modules is not None:
         assert isinstance(freeze_modules, list), "Expected freeze_modules to be a list"
         assert len(freeze_modules) > 0, "Expected freeze_modules to be non-empty"
         callbacks.append(run.Config(LayerFreezer, freeze_modules))

From 9257d1250c5a4d4dca52a36a780d64b953d10850 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 10:20:28 -0700
Subject: [PATCH 08/16] fix

Signed-off-by: Alexandros Koumparoulis
---
 nemo/lightning/pytorch/callbacks/layer_freezer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/nemo/lightning/pytorch/callbacks/layer_freezer.py b/nemo/lightning/pytorch/callbacks/layer_freezer.py
index 6c4e74412f2a..1267585caac6 100644
--- a/nemo/lightning/pytorch/callbacks/layer_freezer.py
+++ b/nemo/lightning/pytorch/callbacks/layer_freezer.py
@@ -11,8 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Tuple, Union
-import math
+from typing import List
 
 import pytorch_lightning as pl
 from lightning.pytorch.callbacks.callback import Callback

From 8de42284be3dc6a214e9299de2805e3d89a9568f Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 10:25:14 -0700
Subject: [PATCH 09/16] fix

Signed-off-by: Alexandros Koumparoulis
---
 .../pytorch/callbacks/layer_freezer.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/nemo/lightning/pytorch/callbacks/layer_freezer.py b/nemo/lightning/pytorch/callbacks/layer_freezer.py
index 1267585caac6..97b163157986 100644
--- a/nemo/lightning/pytorch/callbacks/layer_freezer.py
+++ b/nemo/lightning/pytorch/callbacks/layer_freezer.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 from typing import List
 
-import pytorch_lightning as pl
 from lightning.pytorch.callbacks.callback import Callback
 
 from nemo.lightning.io.mixin import IOMixin
@@ -58,10 +57,24 @@ def _apply_freeze(self, module, freeze: bool):
         module.eval() if freeze else module.train()
 
     def on_train_batch_start(self, trainer, pl_module, *_):
+        """
+        freezes layers listed in frozen_layers
+
+        Args:
+            trainer (Trainer): the trainer
+            pl_module (LightningModule): model
+        """
         for name in self.frozen_layers:
             submod = _resolve_attr(pl_module, name)
-            self._apply_freeze(submod, should_be_frozen)
+            self._apply_freeze(submod, True)
 
-    # In case we resume from checkpoint, re-establish correct state
     def on_train_start(self, trainer, pl_module):
+        """
+        on_train_start
+        In case we resume from checkpoint, re-establish correct state
+
+        Args:
+            trainer (Trainer): the trainer
+            pl_module (LightningModule): model
+        """
         self.on_train_batch_start(trainer, pl_module, None, 0)

From 48126174c7da0e96ae5b6eb6f8b08822c42299b9 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 10:40:11 -0700
Subject: [PATCH 10/16] add steps

Signed-off-by: Alexandros Koumparoulis
---
 .../pytorch/callbacks/layer_freezer.py | 83 ++++++++++++++++---
 1 file changed, 73 insertions(+), 10 deletions(-)

diff --git a/nemo/lightning/pytorch/callbacks/layer_freezer.py b/nemo/lightning/pytorch/callbacks/layer_freezer.py
index 97b163157986..006c17f920b4 100644
--- a/nemo/lightning/pytorch/callbacks/layer_freezer.py
+++ b/nemo/lightning/pytorch/callbacks/layer_freezer.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List
+from typing import Dict, List, Tuple, Union
+import math
 
 from lightning.pytorch.callbacks.callback import Callback
 
 from nemo.lightning.io.mixin import IOMixin
 
+def make_start_end(spec: Union[int, list[int]]):
+    start, end = 0, 0
+    # Normalize to (start, end) where end==inf means “forever”
+    if isinstance(spec, int):
+        if spec == -1:  # forever
+            start, end = 0, math.inf
+        else:  # first N steps
+            start, end = 0, spec - 1
+    elif isinstance(spec, (list, tuple)) and len(spec) == 2:
+        start, end = spec
+        start = 0 if start == -1 else start
+        end = math.inf if end == -1 else end
+    else:
+        raise ValueError(f"Invalid schedule for '{name}': {spec}")
+    return start, end
+
 class LayerFreezer(Callback, IOMixin):
     """
     Freezes sub-modules of a LightningModule based on the list provided. The list of layers should
     be the full FQN.
 
     Instantiate
     -----------
+    # to keep layers frozen for all training
     callback = LayerFreezer(['layer1', 'layer2',])
+    # for some steps
+    callback = LayerFreezer({'layer1': 10, 'layer2': (10, 100)})
+
     trainer = pl.Trainer(callbacks=[callback], ...)
     """
-    def __init__(self, frozen_layers: List[str]):
+    def __init__(self, schedule: Union[List[str], Dict[str, ScheduleValue]]):
         """
         Args
         ----
-        frozen_layers: List[str] list of layers that are frozen
+        schedule: Union[list, dict]
+            - dict
+                key = attribute path of sub-module inside LightningModule
+                value = one of
+                    : -1 -> frozen for entire run
+                    : N (int>0) -> frozen for first N steps (0..N-1)
+                    : [start, end] -> frozen if start <= step <= end
+                      use -1 for "until end of training"
+            - list:
+                key = attribute path of sub-module inside LightningModule
+                value = -1 (hardcoded; use a dict if you want to specify manually).
         """
         super().__init__()
-        self.frozen_layers = frozen_layers
+        assert isinstance(schedule, (list, dict)), type(schedule)
+        if isinstance(schedule, list):
+            schedule = {
+                item: -1
+                for item in schedule
+            }
+
+        self.schedule: Dict[str, Tuple[int, float]] = {}
+        self.frozen_state: Dict[str, bool] = {}  # last applied state
+
+        for name, spec in schedule.items():
+            self.schedule[name] = make_start_end(spec)
+
+    # --------------------------------------------------------------------- #
+    # internal helpers
+    # --------------------------------------------------------------------- #
+    @staticmethod
+    def _resolve_attr(root, path: str):
+        """
+        Traverse dotted attribute path (“encoder.layer1”) from root.
+        """
+        m = root
+        for part in path.split('.'):
+            m = getattr(m, part)
+        return m
 
     def _apply_freeze(self, module, freeze: bool):
         """
         Enable/disable gradients + switch (eval/train) mode.
         """
         for p in module.parameters():
             p.requires_grad = not freeze
         # Optional: also flip training mode so dropout / BN are disabled.
         module.eval() if freeze else module.train()
 
+    # --------------------------------------------------------------------- #
+    # Lightning hooks
+    # --------------------------------------------------------------------- #
     def on_train_batch_start(self, trainer, pl_module, *_):
         """
         freezes layers listed in frozen_layers
-
         Args:
             trainer (Trainer): the trainer
             pl_module (LightningModule): model
         """
-        for name in self.frozen_layers:
-            submod = _resolve_attr(pl_module, name)
-            self._apply_freeze(submod, True)
+        step = trainer.global_step
+
+        for name, (start, end) in self.schedule.items():
+            should_be_frozen = (start <= step <= end)
+            # skip if status unchanged since last check
+            if self.frozen_state.get(name, None) == should_be_frozen:
+                continue
+
+            submod = self._resolve_attr(pl_module, name)
+            self._apply_freeze(submod, should_be_frozen)
+            self.frozen_state[name] = should_be_frozen
 
     def on_train_start(self, trainer, pl_module):
         """
         on_train_start
         In case we resume from checkpoint, re-establish correct state
-
         Args:
             trainer (Trainer): the trainer
             pl_module (LightningModule): model
         """
         self.on_train_batch_start(trainer, pl_module, None, 0)

From cc5448aad13faf27bc63bf0ee79355900c02d022 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 10:42:16 -0700
Subject: [PATCH 11/16] fix

Signed-off-by: Alexandros Koumparoulis
---
 nemo/lightning/pytorch/callbacks/layer_freezer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/nemo/lightning/pytorch/callbacks/layer_freezer.py b/nemo/lightning/pytorch/callbacks/layer_freezer.py
index 006c17f920b4..96e26a53a61c 100644
--- a/nemo/lightning/pytorch/callbacks/layer_freezer.py
+++ b/nemo/lightning/pytorch/callbacks/layer_freezer.py
@@ -17,6 +17,7 @@
 from lightning.pytorch.callbacks.callback import Callback
 
 from nemo.lightning.io.mixin import IOMixin
+ScheduleValue = Union[int, List[int], Tuple[int, int]]  # -1, N, [start, end], [start, -1]
 
 
 def _resolve_attr(root, path: str):
     m = getattr(m, part)
     return m
 
-def make_start_end(spec: Union[int, list[int]]):
+def make_start_end(name, spec: Union[int, list[int]]):
     start, end = 0, 0
     # Normalize to (start, end) where end==inf means “forever”
 
         for name, spec in schedule.items():
-            self.schedule[name] = make_start_end(spec)
+            self.schedule[name] = make_start_end(name, spec)

From c1382c039759fa10f4ec9c58ed982ba994d30045 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 11:23:50 -0700
Subject: [PATCH 12/16] add docstring

Signed-off-by: Alexandros Koumparoulis
---
 .../pytorch/callbacks/layer_freezer.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/nemo/lightning/pytorch/callbacks/layer_freezer.py b/nemo/lightning/pytorch/callbacks/layer_freezer.py
index 96e26a53a61c..273af394b561 100644
--- a/nemo/lightning/pytorch/callbacks/layer_freezer.py
+++ b/nemo/lightning/pytorch/callbacks/layer_freezer.py
@@ -29,7 +29,23 @@ def _resolve_attr(root, path: str):
     m = getattr(m, part)
     return m
 
-def make_start_end(name, spec: Union[int, list[int]]):
+def make_start_end(name: str, spec: Union[int, list[int]]):
+    """Translates spec to start/end steps, for example,
+    spec = -1 -> (0, inf)
+    spec = N (int>0) -> (0, N-1)
+    spec = [start, end] -> (start, end)
+
+
+    Args:
+        name (str): layer name
+        spec (Union[int, list[int]]): freeze schedule spec.
+
+    Raises:
+        ValueError: if spec is not int/list/tuple
+
+    Returns:
+        tuple(int, int): returns start/end
+    """
     start, end = 0, 0
     # Normalize to (start, end) where end==inf means “forever”

From 50b032edd3cfdd765677d24b5ef78d862ff12978 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 24 Jun 2025 11:29:31 -0700
Subject: [PATCH 13/16] fix

Signed-off-by: Alexandros Koumparoulis
---
 nemo/lightning/pytorch/callbacks/layer_freezer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/nemo/lightning/pytorch/callbacks/layer_freezer.py b/nemo/lightning/pytorch/callbacks/layer_freezer.py
index 273af394b561..7d43d895ef8d 100644
--- a/nemo/lightning/pytorch/callbacks/layer_freezer.py
+++ b/nemo/lightning/pytorch/callbacks/layer_freezer.py
@@ -46,7 +46,6 @@ def make_start_end(name: str, spec: Union[int, list[int]]):
     Returns:
         tuple(int, int): returns start/end
     """
-    start, end = 0, 0
     # Normalize to (start, end) where end==inf means “forever”
     if isinstance(spec, int):
         if spec == -1: # forever
             start, end = 0, math.inf
         else: # first N steps
             start, end = 0, spec - 1
     elif isinstance(spec, (list, tuple)) and len(spec) == 2:
         start, end = spec
-        start = 0 if start == -1 else start
-        end = math.inf if end == -1 else end
+        start = max(start, 0)
+        end = max(end, math.inf)
     else:
         raise ValueError(f"Invalid schedule for '{name}': {spec}")
     return start, end

From 19dab7de34697468c98d764be17ba38f6e297a2a Mon Sep 17 00:00:00 2001
From: akoumpa
Date: Tue, 24 Jun 2025 18:38:02 +0000
Subject: [PATCH 14/16] Apply isort and black reformatting

Signed-off-by: akoumpa
---
 .../llm/recipes/hf_auto_model_for_causal_lm.py |  2 +-
 .../pytorch/callbacks/layer_freezer.py         | 17 +++++++++--------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
index d1e1e59f3bcd..2289174c1823 100644
--- a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
@@ -28,8 +28,8 @@
 from nemo.collections.llm.peft.lora import LoRA
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.optim.adam import pytorch_adam_with_cosine_annealing
-from nemo.utils.exp_manager import TimingCallback
 from nemo.lightning.pytorch.callbacks.layer_freezer import LayerFreezer
+from nemo.utils.exp_manager import TimingCallback
 
 NAME = "hf_auto_model_for_causal_lm"
 
diff --git a/nemo/lightning/pytorch/callbacks/layer_freezer.py b/nemo/lightning/pytorch/callbacks/layer_freezer.py
index 7d43d895ef8d..2872a4a21b44 100644
--- a/nemo/lightning/pytorch/callbacks/layer_freezer.py
+++ b/nemo/lightning/pytorch/callbacks/layer_freezer.py
@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Tuple, Union
 import math
+from typing import Dict, List, Tuple, Union
 
 from lightning.pytorch.callbacks.callback import Callback
 
 from nemo.lightning.io.mixin import IOMixin
+
 ScheduleValue = Union[int, List[int], Tuple[int, int]]  # -1, N, [start, end], [start, -1]
 
 
 def _resolve_attr(root, path: str):
     # Normalize to (start, end) where end==inf means “forever”
     if isinstance(spec, int):
-        if spec == -1: # forever
+        if spec == -1:  # forever
             start, end = 0, math.inf
-        else: # first N steps
+        else:  # first N steps
             start, end = 0, spec - 1
     elif isinstance(spec, (list, tuple)) and len(spec) == 2:
         start, end = spec
 
     trainer = pl.Trainer(callbacks=[callback], ...)
     """
+
     def __init__(self, schedule: Union[List[str], Dict[str, ScheduleValue]]):
         super().__init__()
         assert isinstance(schedule, (list, dict)), type(schedule)
         if isinstance(schedule, list):
-            schedule = {
-                item: -1
-                for item in schedule
-            }
+            schedule = {item: -1 for item in schedule}
 
         self.schedule: Dict[str, Tuple[int, float]] = {}
         self.frozen_state: Dict[str, bool] = {}  # last applied state
 
         for name, (start, end) in self.schedule.items():
-            should_be_frozen = (start <= step <= end)
+            should_be_frozen = start <= step <= end
             # skip if status unchanged since last check
             if self.frozen_state.get(name, None) == should_be_frozen:
                 continue

From efe99c49d14e16ff0f7a8103abba3ad95e72c207 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Date: Tue, 24 Jun 2025 18:26:38 -0700
Subject: [PATCH 15/16] Update layer_freezer.py

Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
---
 nemo/lightning/pytorch/callbacks/layer_freezer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/lightning/pytorch/callbacks/layer_freezer.py b/nemo/lightning/pytorch/callbacks/layer_freezer.py
index 2872a4a21b44..069d33e6ba51 100644
--- a/nemo/lightning/pytorch/callbacks/layer_freezer.py
+++ b/nemo/lightning/pytorch/callbacks/layer_freezer.py
@@ -57,7 +57,7 @@ def make_start_end(name: str, spec: Union[int, list[int]]):
     elif isinstance(spec, (list, tuple)) and len(spec) == 2:
         start, end = spec
         start = max(start, 0)
-        end = max(end, math.inf)
+        if end < 0: end = math.inf
     else:
         raise ValueError(f"Invalid schedule for '{name}': {spec}")
     return start, end

From 2ca72dd147965ecbf85f6e89a49cd340a4ff6b3a Mon Sep 17 00:00:00 2001
From: akoumpa
Date: Wed, 25 Jun 2025 01:27:26 +0000
Subject: [PATCH 16/16] Apply isort and black reformatting

Signed-off-by: akoumpa
---
 nemo/lightning/pytorch/callbacks/layer_freezer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/nemo/lightning/pytorch/callbacks/layer_freezer.py b/nemo/lightning/pytorch/callbacks/layer_freezer.py
b/nemo/lightning/pytorch/callbacks/layer_freezer.py
index 069d33e6ba51..26e01c3df399 100644
--- a/nemo/lightning/pytorch/callbacks/layer_freezer.py
+++ b/nemo/lightning/pytorch/callbacks/layer_freezer.py
@@ -57,7 +57,8 @@ def make_start_end(name: str, spec: Union[int, list[int]]):
     elif isinstance(spec, (list, tuple)) and len(spec) == 2:
         start, end = spec
         start = max(start, 0)
-        if end < 0: end = math.inf
+        if end < 0:
+            end = math.inf
     else:
         raise ValueError(f"Invalid schedule for '{name}': {spec}")
     return start, end
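
For reference, a minimal usage sketch of the LayerFreezer API as it stands after PATCH 16. This is illustrative only, not part of the patch series: the module paths ('model.embed_tokens', 'model.layers.0') and step counts are hypothetical placeholders, and the trainer setup is a plain Lightning example rather than the NeMo recipe plumbing above.

    # Illustrative sketch only: 'model.embed_tokens' and 'model.layers.0' are
    # hypothetical module paths; substitute the full FQNs of your own sub-modules.
    import lightning.pytorch as pl

    from nemo.lightning.pytorch.callbacks.layer_freezer import LayerFreezer

    # Dict form: -1 freezes for the entire run; [start, end] freezes while
    # start <= trainer.global_step <= end (end == -1 means until training ends).
    freezer = LayerFreezer({
        'model.embed_tokens': -1,
        'model.layers.0': [0, 99],
    })

    # List form: every listed module stays frozen for the whole run,
    # equivalent to a dict mapping each path to -1.
    # freezer = LayerFreezer(['model.embed_tokens', 'model.layers.0'])

    trainer = pl.Trainer(callbacks=[freezer], max_steps=1000)
    # trainer.fit(module, datamodule)
    # On resume from a checkpoint, on_train_start re-applies the schedule,
    # and on_train_batch_start only touches modules whose frozen state changed.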