@@ -267,6 +267,13 @@ def setup(self, cfg: DictConfig) -> None:
             self._steps_per_epoch = self.max_steps_per_epoch
         self.global_step = self.epochs_run * self._steps_per_epoch
 
+        # Setup lr scheduler
+        self._lr_scheduler = self._setup_lr_scheduler(
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
+            num_training_steps=self.total_epochs * self._steps_per_epoch,
+            last_epoch=self.global_step - 1,
+        )
+
         # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
         # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
         self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
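For context, the `lr_scheduler` section that `cfg.get("lr_scheduler", None)` reads could look like the sketch below; the component path and warmup value are illustrative assumptions in the style of existing torchtune configs, not part of this diff.

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "lr_scheduler": {
            # Assumed component path, mirroring torchtune's LoRA recipe configs
            "_component_": "torchtune.modules.get_cosine_schedule_with_warmup",
            "num_warmup_steps": 100,
        },
    }
)

# When the key is missing, cfg.get("lr_scheduler", None) returns None and the
# recipe falls back to a constant learning rate, per the log message below.
print(cfg.get("lr_scheduler", None))
```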
@@ -422,6 +429,53 @@ def _setup_optimizer(
         log.info("Optimizer is initialized.")
         return optimizer
 
+    def _setup_lr_scheduler(
+        self,
+        cfg_lr_scheduler: Optional[DictConfig],
+        num_training_steps: int,
+        last_epoch: int,
+    ) -> Optional[Optimizer]:
+        """
+        Set up the learning rate scheduler based on the provided configuration.
+        It handles both standard optimization and optimizer-in-backward cases, and supports
+        schedulers from both torchtune.modules and torch.optim.
+
+        Args:
+            cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
+            num_training_steps (int): The total number of training steps.
+            last_epoch (int): The index of the last epoch.
+
+        Returns:
+            lr_scheduler (Optional[Optimizer]): The learning rate scheduler.
+        """
+        if cfg_lr_scheduler is None:
+            log.info(
+                "No learning rate scheduler configured. Using constant learning rate."
+            )
+            return None
+
+        if self._optimizer_in_bwd:
+            # Use the first optimizer from the wrapper to represent the learning rate
+            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
+        else:
+            # Standard case: use the single optimizer
+            optimizer = self._optimizer
+
+        # Instantiate the learning rate scheduler
+        lr_scheduler = config.instantiate(
+            cfg_lr_scheduler,
+            optimizer,
+            num_training_steps=num_training_steps,
+            last_epoch=last_epoch,
+        )
+
+        if self._optimizer_in_bwd:
+            # Modify the scheduler for optimizer_in_bwd case
+            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)
+
+        log.info("Learning rate scheduler is initialized.")
+        return lr_scheduler
+
     def _setup_data(
         self,
         cfg_dataset: DictConfig,
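A minimal plain-PyTorch sketch of the two ideas the new method relies on: picking one representative optimizer when optimizer-in-backward keeps an optimizer per parameter, and passing `last_epoch = global_step - 1` so a freshly built scheduler resumes at the right point. The `optim_map` dict, parameter names, and linear-decay lambda below are invented for illustration and are not the recipe's actual objects.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(3)]

# Stand-in for self._optim_ckpt_wrapper.optim_map: one optimizer per parameter
# when gradients are applied in the backward pass.
optim_map = {f"p{i}": torch.optim.SGD([p], lr=1e-3) for i, p in enumerate(params)}
representative = next(iter(optim_map.values()))

num_training_steps = 1000
global_step = 0  # would be epochs_run * steps_per_epoch when resuming

scheduler = LambdaLR(
    representative,
    lr_lambda=lambda step: 1.0 - step / num_training_steps,  # simple linear decay
    last_epoch=global_step - 1,  # -1 means "start from scratch"
)
print(scheduler.get_last_lr())
```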
@@ -586,6 +640,9 @@ def train(self) -> None:
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
 
+                    # Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
+                    if self._lr_scheduler is not None:
+                        self._lr_scheduler.step()
                     self.global_step += 1
 
                     loss_to_log = running_loss.item()
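A toy loop showing the ordering this hunk establishes in `train()`: the scheduler steps once per optimizer step (i.e. once per gradient-accumulation boundary) and only after `optimizer.step()`, which is what avoids PyTorch's ordering warning. The model, optimizer, and step counts here are placeholders, not the recipe's.

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

gradient_accumulation_steps = 4
global_step = 0

for idx in range(32):
    loss = model(torch.randn(8, 4)).mean()
    loss.backward()
    if (idx + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()  # after optimizer.step(), so no ordering warning
        global_step += 1
```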