diff --git a/torchtrain/lr_scheduling.py b/torchtrain/lr_scheduling.py
new file mode 100644
index 0000000000..c477e34293
--- /dev/null
+++ b/torchtrain/lr_scheduling.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+from torch.optim.lr_scheduler import LambdaLR
+
+# global states for scheduling
+# these are needed as LambdaLR does not support argument passing
+_warmup_steps = 2
+_decay_steps = 0
+
+
+def linear_warmup_linear_decay(current_step: int) -> float:
+    """Computes linear warmup followed by linear decay.
+    Per LambdaLR requirement, this is accomplished by returning
+    a multiplicative factor to adjust the learning rate to
+    create the desired schedule.
+    """
+    if current_step < _warmup_steps:
+        # linear warmup
+        # 0-indexed step, hence + 1 adjustments
+        current_step += 1
+        curr_adjustment = float(current_step / (_warmup_steps + 1))
+
+    else:
+        # linear decay
+        normalized_step = _decay_steps - (current_step - _warmup_steps)
+        curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps
+
+    return curr_adjustment
+
+
+def get_lr_scheduler(optimizer, args):
+    """Build a linear warmup and linear decay scheduler"""
+    global _warmup_steps, _decay_steps
+    _warmup_steps = max(int(args.steps * args.warmup_pct), 2)
+    _decay_steps = float(max(1, args.steps - _warmup_steps))
+
+    warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
+    return warmup_scheduler
diff --git a/train.py b/train.py
index d5a3817e1a..6b186a154f 100644
--- a/train.py
+++ b/train.py
@@ -21,6 +21,7 @@
 )
 from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
 from torchtrain.parallelisms import models_parallelize_fns
+from torchtrain.lr_scheduling import get_lr_scheduler
 
 
 @dataclass
@@ -46,7 +47,7 @@ def build_grad_scaler(model):
     # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
     if model.mixed_precision.param_dtype == torch.float16:
         enable_grad_scaling = True
-        rank0_log(f"Enabling gradient scaling for mixed precision training.")
+        rank0_log("Enabling gradient scaling for mixed precision training.")
     else:
         enable_grad_scaling = False
         rank0_log("Gradient scaling not enabled.")
@@ -85,8 +86,8 @@ def main(args):
     assert isinstance(model, FSDP)
 
     # build optimizer after apply parallelisms to the model
-    # TODO: add scheduler if needed
     optimizer = build_optimizer(model, args)
+    scheduler = get_lr_scheduler(optimizer, args)
 
     scaler = build_grad_scaler(model)
 
@@ -144,7 +145,10 @@ def main(args):
            train_state.current_loss = loss.item()
            train_state.losses.append(train_state.current_loss)
 
-           rank0_log(f"current loss: {train_state.current_loss}")
+           rank0_log(
+               f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
+           )
+           scheduler.step()
 
 
 if __name__ == "__main__":
@@ -171,9 +175,18 @@ def main(args):
     parser.add_argument(
         "--optimizer", type=str, default="AdamW", help="optimizer to use"
     )
-    parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
+    parser.add_argument("--lr", type=float, default=8e-4, help="learning rate to use")
+    parser.add_argument(
+        "--warmup_pct",
+        type=float,
+        default=0.10,
+        help="percentage of total training steps to use for warmup",
+    )
     parser.add_argument(
-        "--max_norm", type=Union[float, int], default=1.0, help="max norm for gradient clipping"
+        "--max_norm",
+        type=Union[float, int],
+        default=1.0,
+        help="max norm for gradient clipping",
     )
     parser.add_argument(
         "--steps", type=int, default=-1, help="how many train steps to run"
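For quick sanity checking outside the training script, the sketch below (not part of the diff) drives `get_lr_scheduler` with a toy model and a hand-built `argparse.Namespace` carrying only the two fields it reads, `steps` and `warmup_pct`; the toy model, step count, and printed output are illustrative assumptions, not part of the PR.

```python
# Standalone sketch, assuming torchtrain.lr_scheduling is importable as in the
# diff above; the Namespace stands in for the parsed command-line args.
from argparse import Namespace

import torch

from torchtrain.lr_scheduling import get_lr_scheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4)

args = Namespace(steps=20, warmup_pct=0.10)  # only the fields get_lr_scheduler reads
scheduler = get_lr_scheduler(optimizer, args)

for step in range(args.steps):
    optimizer.step()    # forward/backward omitted for brevity
    scheduler.step()    # advance the LambdaLR schedule once per training step
    print(f"step {step}: lr={scheduler.get_last_lr()[0]:.2e}")
```

Per the lambda in the diff, the multiplicative factor climbs from `1 / (_warmup_steps + 1)` to `_warmup_steps / (_warmup_steps + 1)` during warmup, then the decay branch restarts at 1.0 and falls linearly toward zero over the remaining `_decay_steps` steps.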