Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions torchtrain/lr_scheduling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

from torch.optim.lr_scheduler import LambdaLR

# global states for scheduling
# these are needed as LambdaLR does not support argument passing
_warmup_steps = 2
_decay_steps = 0


def linear_warmup_linear_decay(current_step: int) -> float:
"""Computes linear warmup followed by linear decay.
Per LambdaLR requirement, this is accomplished by returning
a multiplicative factor to adjust the learning rate to
create the desired schedule.
"""
if current_step < _warmup_steps:
# linear warmup
# 0-indexed step, hence + 1 adjustments
current_step += 1
curr_adjustment = float(current_step / (_warmup_steps + 1))

else:
# linear decay
normalized_step = _decay_steps - (current_step - _warmup_steps)
curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps

return curr_adjustment


def get_lr_scheduler(optimizer, args):
"""Build a linear warmup and linear decay scheduler"""
global _warmup_steps, _decay_steps
_warmup_steps = max(int(args.steps * args.warmup_pct), 2)
_decay_steps = float(max(1, args.steps - _warmup_steps))

warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
return warmup_scheduler
23 changes: 18 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
from torchtrain.parallelisms import models_parallelize_fns
from torchtrain.lr_scheduling import get_lr_scheduler


@dataclass
Expand All @@ -46,7 +47,7 @@ def build_grad_scaler(model):
# apply gradient scaling if mixed precision training is enabled with fp16 param dtype
if model.mixed_precision.param_dtype == torch.float16:
enable_grad_scaling = True
rank0_log(f"Enabling gradient scaling for mixed precision training.")
rank0_log("Enabling gradient scaling for mixed precision training.")
else:
enable_grad_scaling = False
rank0_log("Gradient scaling not enabled.")
Expand Down Expand Up @@ -85,8 +86,8 @@ def main(args):
assert isinstance(model, FSDP)

# build optimizer after apply parallelisms to the model
# TODO: add scheduler if needed
optimizer = build_optimizer(model, args)
scheduler = get_lr_scheduler(optimizer, args)

scaler = build_grad_scaler(model)

Expand Down Expand Up @@ -144,7 +145,10 @@ def main(args):
train_state.current_loss = loss.item()
train_state.losses.append(train_state.current_loss)

rank0_log(f"current loss: {train_state.current_loss}")
rank0_log(
f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
)
scheduler.step()


if __name__ == "__main__":
Expand All @@ -171,9 +175,18 @@ def main(args):
parser.add_argument(
"--optimizer", type=str, default="AdamW", help="optimizer to use"
)
parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
parser.add_argument("--lr", type=float, default=8e-4, help="learning rate to use")
parser.add_argument(
"--warmup_pct",
type=float,
default=0.10,
help="percentage of total training steps to use for warmup",
)
parser.add_argument(
"--max_norm", type=Union[float, int], default=1.0, help="max norm for gradient clipping"
"--max_norm",
type=Union[float, int],
default=1.0,
help="max norm for gradient clipping",
)
parser.add_argument(
"--steps", type=int, default=-1, help="how many train steps to run"
Expand Down