
Commit a8a64ec

user074 and ebsmothers authored
Add LR Scheduler to single device full finetune (#1350)
Co-authored-by: ebsmothers <ebs@meta.com>
1 parent 7af77c7 commit a8a64ec

5 files changed

Lines changed: 125 additions & 3 deletions


recipes/configs/llama2/7B_full_low_memory.yaml

Lines changed: 4 additions & 1 deletion
@@ -53,7 +53,10 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: bitsandbytes.optim.PagedAdamW
-  lr: 2e-5
+  lr: 1e-5
+lr_scheduler:
+  _component_: torchtune.modules.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
 optimizer_in_bwd: True
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
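For reference, the recipe resolves this lr_scheduler block into a call to torchtune.modules.get_cosine_schedule_with_warmup. A minimal sketch of the equivalent Python, with a placeholder model, optimizer, and total step count (only lr: 1e-5 and num_warmup_steps: 100 come from the config above):

import torch
from torchtune.modules import get_cosine_schedule_with_warmup

# Placeholder model/optimizer standing in for the recipe's real setup.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Cosine decay with a 100-step linear warmup; the recipe supplies
# num_training_steps as total_epochs * steps_per_epoch.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,  # placeholder total
)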

recipes/configs/llama3/8B_full_single_device.yaml

Lines changed: 4 additions & 1 deletion
@@ -52,7 +52,10 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: bitsandbytes.optim.PagedAdamW8bit
-  lr: 2e-5
+  lr: 1e-5
+lr_scheduler:
+  _component_: torchtune.modules.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null

recipes/full_finetune_single_device.py

Lines changed: 57 additions & 0 deletions
@@ -267,6 +267,13 @@ def setup(self, cfg: DictConfig) -> None:
             self._steps_per_epoch = self.max_steps_per_epoch
         self.global_step = self.epochs_run * self._steps_per_epoch

+        # Setup lr scheduler
+        self._lr_scheduler = self._setup_lr_scheduler(
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
+            num_training_steps=self.total_epochs * self._steps_per_epoch,
+            last_epoch=self.global_step - 1,
+        )
+
         # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
         # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
         self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
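Passing last_epoch=self.global_step - 1 is what lets a resumed run rejoin the schedule mid-stream: on a fresh run global_step is 0, giving PyTorch's default last_epoch=-1, while on resume the scheduler fast-forwards to the saved step. A rough sketch with placeholder numbers (note PyTorch requires initial_lr in the param groups when last_epoch >= 0; in the recipe the restored optimizer state provides it):

import torch
from torchtune.modules import get_cosine_schedule_with_warmup

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Fresh run: last_epoch=-1 (the default), then 500 steps are taken.
sched = get_cosine_schedule_with_warmup(
    opt, num_warmup_steps=100, num_training_steps=1000
)
for _ in range(500):
    opt.step()
    sched.step()

# Resumed run: last_epoch=499 reconstructs the same point in the schedule.
opt2 = torch.optim.AdamW(model.parameters(), lr=1e-5)
for g in opt2.param_groups:
    g.setdefault("initial_lr", g["lr"])  # normally restored from checkpoint
resumed = get_cosine_schedule_with_warmup(
    opt2, num_warmup_steps=100, num_training_steps=1000, last_epoch=499
)
assert abs(resumed.get_last_lr()[0] - sched.get_last_lr()[0]) < 1e-12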
@@ -422,6 +429,53 @@ def _setup_optimizer(
         log.info("Optimizer is initialized.")
         return optimizer

+    def _setup_lr_scheduler(
+        self,
+        cfg_lr_scheduler: Optional[DictConfig],
+        num_training_steps: int,
+        last_epoch: int,
+    ) -> Optional[Optimizer]:
+        """
+        Set up the learning rate scheduler based on the provided configuration.
+        It handles both standard optimization and optimizer-in-backward cases, and supports
+        schedulers from both torchtune.modules and torch.optim.
+
+        Args:
+            cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
+            num_training_steps (int): The total number of training steps.
+            last_epoch (int): The index of the last epoch.
+
+        Returns:
+            lr_scheduler (Optional[Optimizer]): The learning rate scheduler.
+        """
+        if cfg_lr_scheduler is None:
+            log.info(
+                "No learning rate scheduler configured. Using constant learning rate."
+            )
+            return None
+
+        if self._optimizer_in_bwd:
+            # Use the first optimizer from the wrapper to represent the learning rate
+            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
+        else:
+            # Standard case: use the single optimizer
+            optimizer = self._optimizer
+
+        # Instantiate the learning rate scheduler
+        lr_scheduler = config.instantiate(
+            cfg_lr_scheduler,
+            optimizer,
+            num_training_steps=num_training_steps,
+            last_epoch=last_epoch,
+        )
+
+        if self._optimizer_in_bwd:
+            # Modify the scheduler for optimizer_in_bwd case
+            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)
+
+        log.info("Learning rate scheduler is initialized.")
+        return lr_scheduler
+
     def _setup_data(
         self,
         cfg_dataset: DictConfig,
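Here config.instantiate resolves the _component_ path in the DictConfig and forwards the remaining config keys together with the extra positional/keyword arguments. A sketch of what the call amounts to for the configs in this PR (placeholder model, optimizer, and step count):

import torch
from omegaconf import OmegaConf
from torchtune import config

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

cfg_lr_scheduler = OmegaConf.create({
    "_component_": "torchtune.modules.get_cosine_schedule_with_warmup",
    "num_warmup_steps": 100,
})

# Roughly equivalent to:
#   get_cosine_schedule_with_warmup(
#       optimizer, num_warmup_steps=100, num_training_steps=1000, last_epoch=-1)
lr_scheduler = config.instantiate(
    cfg_lr_scheduler,
    optimizer,
    num_training_steps=1000,  # placeholder
    last_epoch=-1,
)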
@@ -586,6 +640,9 @@ def train(self) -> None:
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)

+                    # Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
+                    if self._lr_scheduler is not None:
+                        self._lr_scheduler.step()
                     self.global_step += 1

                     loss_to_log = running_loss.item()
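The scheduler is stepped immediately after optimizer.step() because PyTorch (since 1.1) warns when lr_scheduler.step() runs before the first optimizer step; the in-code comment presumably refers to the warning that remains in the optimizer-in-backward case, where optimizer steps happen inside backward(). A self-contained sketch of the required ordering, using a toy model and schedule:

import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
sched = LambdaLR(opt, lambda step: 0.99 ** step)

for _ in range(10):
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    opt.step()                      # update weights first...
    opt.zero_grad(set_to_none=True)
    sched.step()                    # ...then advance the schedule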

tests/recipes/test_full_finetune_single_device.py

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ def _get_test_config_overrides(self):
             "max_steps_per_epoch=2",
             "optimizer=torch.optim.AdamW",
             "optimizer.lr=2e-5",
+            "lr_scheduler.num_warmup_steps=0",
+            "lr_scheduler.num_cycles=0",
             "log_every_n_steps=1",
             "clip_grad_norm=100",
         ] + dummy_alpaca_dataset_config()
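These overrides keep the tests' existing loss expectations valid: with num_warmup_steps=0 and num_cycles=0, a warmup-cosine schedule reduces to a constant multiplier of 1.0, so the learning rate never moves off optimizer.lr. A quick check, assuming the usual HF-style formula (which torchtune's helper is expected to mirror):

import math

def cosine_factor(step, num_warmup_steps=0, num_training_steps=100, num_cycles=0.0):
    # Standard warmup-then-cosine LR multiplier.
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

# With num_cycles=0 the cosine term is cos(0) == 1 at every step.
assert all(cosine_factor(s) == 1.0 for s in range(100))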

torchtune/training/memory.py

Lines changed: 58 additions & 1 deletion
@@ -10,12 +10,12 @@
 from typing import Any, Callable, Dict, Set, Type, Union

 import torch
-
 from torch import nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
 )
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from torch.optim.lr_scheduler import LRScheduler
 from torchtune.utils import get_logger

 _log: logging.Logger = get_logger()
@@ -91,6 +91,7 @@ class OptimizerInBackwardWrapper:

     def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]):
         self.optim_map = optim_map
+        self.lr_scheduler = None

     def state_dict(self) -> Dict[str, Any]:
         """
@@ -136,6 +137,62 @@ def get_optim_key(self, key: str) -> Any:
         """
         return list(self.optim_map.values())[0].param_groups[0][key]

+    def set_lr_scheduler(self, lr_scheduler: LRScheduler) -> None:
+        """
+        Sets the learning rate scheduler and modifies its step method to update all optimizers.
+
+        Args:
+            lr_scheduler (LRScheduler): The learning rate scheduler to use.
+        """
+        self.lr_scheduler = lr_scheduler
+        original_step = self.lr_scheduler.step
+
+        def custom_step(epoch=None):
+            if epoch is None:
+                original_step()
+            else:
+                original_step(epoch)
+            new_lr = self.lr_scheduler.get_last_lr()[0]
+            for opt in self.optim_map.values():
+                for param_group in opt.param_groups:
+                    param_group["lr"] = new_lr
+
+        self.lr_scheduler.step = custom_step
+
+    def step_lr_scheduler(self, epoch: int = None):
+        """
+        Steps the learning rate scheduler if it exists.
+
+        Args:
+            epoch (int, optional): The current epoch number. Defaults to None.
+
+        Raises:
+            RuntimeError: If the LR scheduler has not been set.
+        """
+        if self.lr_scheduler:
+            self.lr_scheduler.step(epoch)
+        else:
+            raise RuntimeError(
+                "LR scheduler has not been set. Call set_lr_scheduler first."
+            )
+
+    def get_last_lr(self) -> float:
+        """
+        Gets the last learning rate from the scheduler if it exists.
+
+        Returns:
+            float: The last learning rate.
+
+        Raises:
+            RuntimeError: If the LR scheduler has not been set.
+        """
+        if self.lr_scheduler:
+            return self.lr_scheduler.get_last_lr()[0]
+        else:
+            raise RuntimeError(
+                "LR scheduler has not been set. Call set_lr_scheduler first."
+            )
+

 def create_optim_in_bwd_wrapper(
     model: torch.nn.Module, optim_dict: Dict[torch.nn.Parameter, torch.optim.Optimizer]
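A small usage sketch of the new wrapper API, using a toy model and scheduler (create_optim_in_bwd_wrapper is the factory shown at the bottom of this diff):

import torch
from torch.optim.lr_scheduler import LambdaLR
from torchtune.training.memory import create_optim_in_bwd_wrapper

model = torch.nn.Linear(4, 4)
# One optimizer per parameter, as in optimizer-in-backward training.
optim_dict = {p: torch.optim.SGD([p], lr=0.1) for p in model.parameters()}
wrapper = create_optim_in_bwd_wrapper(model, optim_dict)

# Build the scheduler against one representative optimizer, then register
# it; the patched step() fans the new LR out to every wrapped optimizer.
rep = next(iter(wrapper.optim_map.values()))
wrapper.set_lr_scheduler(LambdaLR(rep, lambda step: 0.5 ** step))

wrapper.step_lr_scheduler()   # halves the LR everywhere
assert wrapper.get_last_lr() == 0.05
assert all(
    g["lr"] == 0.05
    for opt in wrapper.optim_map.values()
    for g in opt.param_groups
)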
