Commit bd5176c

add linear lr warmup and lr decay scheduler (#23)
This PR adds a linear LR scheduler and includes some automation based on current best practices:

a - Takes the user-provided lr from args as lr_max, and computes the final min_lr for the decay schedule as lr / 10, per the Chinchilla paper (i.e., the total decay spans one order of magnitude).
b - Computes an automated linear warmup schedule covering 10% of total iters, with a minimum warmup of 2 steps.
c - Computes a linear decay schedule after warmup, declining from lr_max to lr_min between the end of warmup and the end of training (per Aaron's latest paper, linear is the preferred schedule).
d - Updates the default learning rate to 8e-4, to provide more visible per-iter results to the user assuming debugModel.

LR scheduling produces a much improved loss curve:

[Screenshot: loss curve with the new LR schedule]

I added two log prints: the warmup schedule as one line, and then the step and current lr at each iter. Both could be disabled if they are too much info.
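For concreteness, here is a small sketch (not part of the commit; the 1000-step run length is a hypothetical value) of how the automated schedule described above works out with the new defaults:

# Illustrative arithmetic only; steps = 1000 is a made-up run length.
steps = 1000                              # --steps
lr_max = 8e-4                             # --lr, used as the peak lr
warmup_steps = max(int(steps * 0.10), 2)  # 10% of iters, minimum of 2 -> 100
decay_steps = steps - warmup_steps        # linear decay over the remainder -> 900
lr_min = lr_max / 10                      # one order of magnitude below lr_max -> 8e-5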
1 parent 83ee9f7 commit bd5176c

File tree

2 files changed (+57 -5 lines changed)


torchtrain/lr_scheduling.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

from torch.optim.lr_scheduler import LambdaLR

# global states for scheduling
# these are needed as LambdaLR does not support argument passing
_warmup_steps = 2
_decay_steps = 0


def linear_warmup_linear_decay(current_step: int) -> float:
    """Computes linear warmup followed by linear decay.
    Per LambdaLR requirement, this is accomplished by returning
    a multiplicative factor to adjust the learning rate to
    create the desired schedule.
    """
    if current_step < _warmup_steps:
        # linear warmup
        # 0-indexed step, hence + 1 adjustments
        current_step += 1
        curr_adjustment = float(current_step / (_warmup_steps + 1))

    else:
        # linear decay
        normalized_step = _decay_steps - (current_step - _warmup_steps)
        curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps

    return curr_adjustment


def get_lr_scheduler(optimizer, args):
    """Build a linear warmup and linear decay scheduler"""
    global _warmup_steps, _decay_steps
    _warmup_steps = max(int(args.steps * args.warmup_pct), 2)
    _decay_steps = float(max(1, args.steps - _warmup_steps))

    warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
    return warmup_scheduler
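
As a rough illustration of the schedule shape this module produces, the following sketch (mine, not part of the diff) sets the module-level globals by hand, mirroring what get_lr_scheduler would do for a hypothetical 20-step run, and evaluates the LambdaLR factor directly:

# Sketch only: set the module globals by hand, mirroring what
# get_lr_scheduler would compute for args.steps = 20, args.warmup_pct = 0.10.
import torchtrain.lr_scheduling as lrs

lrs._warmup_steps = 2     # max(int(20 * 0.10), 2)
lrs._decay_steps = 18.0   # float(max(1, 20 - 2))

factors = [lrs.linear_warmup_linear_decay(step) for step in range(20)]
# warmup produces 1/3 and 2/3; decay then starts at 1.0 and falls
# linearly to 1/18. LambdaLR multiplies the optimizer's base lr
# (i.e. args.lr) by each of these factors at the corresponding step.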

train.py

Lines changed: 18 additions & 5 deletions
@@ -21,6 +21,7 @@
 )
 from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
 from torchtrain.parallelisms import models_parallelize_fns
+from torchtrain.lr_scheduling import get_lr_scheduler
 
 
 @dataclass
@@ -46,7 +47,7 @@ def build_grad_scaler(model):
     # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
     if model.mixed_precision.param_dtype == torch.float16:
         enable_grad_scaling = True
-        rank0_log(f"Enabling gradient scaling for mixed precision training.")
+        rank0_log("Enabling gradient scaling for mixed precision training.")
     else:
         enable_grad_scaling = False
         rank0_log("Gradient scaling not enabled.")
@@ -85,8 +86,8 @@ def main(args):
     assert isinstance(model, FSDP)
 
     # build optimizer after apply parallelisms to the model
-    # TODO: add scheduler if needed
     optimizer = build_optimizer(model, args)
+    scheduler = get_lr_scheduler(optimizer, args)
 
     scaler = build_grad_scaler(model)
 
@@ -144,7 +145,10 @@ def main(args):
         train_state.current_loss = loss.item()
         train_state.losses.append(train_state.current_loss)
 
-        rank0_log(f"current loss: {train_state.current_loss}")
+        rank0_log(
+            f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
+        )
+        scheduler.step()
 
 
 if __name__ == "__main__":
@@ -171,9 +175,18 @@ def main(args):
     parser.add_argument(
         "--optimizer", type=str, default="AdamW", help="optimizer to use"
     )
-    parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
+    parser.add_argument("--lr", type=float, default=8e-4, help="learning rate to use")
+    parser.add_argument(
+        "--warmup_pct",
+        type=float,
+        default=0.10,
+        help="percentage of total training steps to use for warmup",
+    )
     parser.add_argument(
-        "--max_norm", type=Union[float, int], default=1.0, help="max norm for gradient clipping"
+        "--max_norm",
+        type=Union[float, int],
+        default=1.0,
+        help="max norm for gradient clipping",
     )
     parser.add_argument(
         "--steps", type=int, default=-1, help="how many train steps to run"
