 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List
+from typing import Dict, List, Tuple, Union
+import math

 from lightning.pytorch.callbacks.callback import Callback

@@ -27,25 +28,78 @@ def _resolve_attr(root, path: str):
         m = getattr(m, part)
     return m

+# Per-layer schedule spec: -1, a positive int N, or a [start, end] pair.
+ScheduleValue = Union[int, List[int], Tuple[int, int]]
+
+def make_start_end(name: str, spec: ScheduleValue):
+    # Normalize to (start, end) where end == math.inf means "forever"
+    if isinstance(spec, int):
+        if spec == -1:  # forever
+            start, end = 0, math.inf
+        else:  # first N steps
+            start, end = 0, spec - 1
+    elif isinstance(spec, (list, tuple)) and len(spec) == 2:
+        start, end = spec
+        start = 0 if start == -1 else start
+        end = math.inf if end == -1 else end
+    else:
+        raise ValueError(f"Invalid schedule for '{name}': {spec}")
+    return start, end
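+# Illustrative normalizations (a sketch of the rules above; end == math.inf
+# means "until end of training", and "enc" is a placeholder layer name):
+#   make_start_end("enc", -1)      -> (0, inf)
+#   make_start_end("enc", 10)      -> (0, 9)    # frozen for steps 0..9
+#   make_start_end("enc", [5, -1]) -> (5, inf)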
+
 class LayerFreezer(Callback, IOMixin):
     """
-    Freezes sub-modules of a LightningModule based on the list provided. The list of layers should
-    be the full FQN.
+    Freezes sub-modules of a LightningModule according to the schedule provided.
+    Layer names must be full FQNs (dotted attribute paths).

     Instantiate
     -----------
+    # to keep layers frozen for the entire run
     callback = LayerFreezer(['layer1', 'layer2'])
+    # to freeze layers for the first N steps or a [start, end] range
+    callback = LayerFreezer({'layer1': 10, 'layer2': (10, 100)})
+
     trainer = pl.Trainer(callbacks=[callback], ...)
     """
-
-    def __init__(self, frozen_layers: List[str]):
+    def __init__(self, schedule: Union[List[str], Dict[str, ScheduleValue]]):
         """
         Args
         ----
-        frozen_layers: List[str] list of layers that are frozen
+        schedule: Union[list, dict]
+            - dict
+                key   = attribute path of a sub-module inside the LightningModule
+                value = one of
+                    -1            -> frozen for the entire run
+                    N (int > 0)   -> frozen for the first N steps (0..N-1)
+                    [start, end]  -> frozen while start <= step <= end;
+                                     use -1 for "until end of training"
+            - list
+                each item is the attribute path of a sub-module inside the
+                LightningModule; every listed layer is frozen for the entire
+                run (equivalent to a dict with all values set to -1).
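+
+            For example (layer names are placeholders): {'encoder': [0, 500]}
+            keeps `encoder` frozen for steps 0 through 500 inclusive, then
+            unfreezes it from step 501 onward.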
4676 """
4777 super ().__init__ ()
48- self .frozen_layers = frozen_layers
78+ assert isinstance (schedule , (list , dict )), type (schedule )
79+ if isinstance (schedule , list ):
80+ schedule = {
81+ item : - 1
82+ for item in schedule
83+ }
84+
85+ self .schedule : Dict [str , Tuple [int , float ]] = {}
86+ self .frozen_state : Dict [str , bool ] = {} # last applied state
87+
88+ for name , spec in schedule .items ():
89+ self .schedule [name ] = make_start_end (spec )
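+        # e.g. the schedule {'layer1': 10, 'layer2': (10, 100)} from the class
+        # docstring normalizes to {'layer1': (0, 9), 'layer2': (10, 100)}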
+
+    # --------------------------------------------------------------------- #
+    # internal helpers
+    # --------------------------------------------------------------------- #
+    @staticmethod
+    def _resolve_attr(root, path: str):
+        """
+        Traverse a dotted attribute path ("encoder.layer1") from root.
+        """
+        m = root
+        for part in path.split('.'):
+            m = getattr(m, part)
+        return m

     def _apply_freeze(self, module, freeze: bool):
         """
@@ -56,23 +110,32 @@ def _apply_freeze(self, module, freeze: bool):
         # Optional: also flip training mode so dropout / BN are disabled.
         module.eval() if freeze else module.train()

+    # --------------------------------------------------------------------- #
+    # Lightning hooks
+    # --------------------------------------------------------------------- #
     def on_train_batch_start(self, trainer, pl_module, *_):
         """
-        freezes layers listed on frozen_layers
-
+        Freezes/unfreezes layers according to the schedule at the current step.
         Args:
             trainer (Trainer): the trainer
             pl_module (LightningModule): model
         """
-        for name in self.frozen_layers:
-            submod = _resolve_attr(pl_module, name)
-            self._apply_freeze(submod, True)
+        step = trainer.global_step
+
+        for name, (start, end) in self.schedule.items():
+            should_be_frozen = (start <= step <= end)
+            # skip if status unchanged since last check
+            if self.frozen_state.get(name, None) == should_be_frozen:
+                continue
+
+            submod = self._resolve_attr(pl_module, name)
+            self._apply_freeze(submod, should_be_frozen)
+            self.frozen_state[name] = should_be_frozen

     def on_train_start(self, trainer, pl_module):
         """
-        on_train_start
-        In case we resume from checkpoint, re-establish correct state
-
+        In case we resume from a checkpoint, re-establish the correct state.
         Args:
             trainer (Trainer): the trainer
             pl_module (LightningModule): model