From 713ccd2a851d53d78299c43d5c2a6955c9181e89 Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Sun, 28 Jan 2024 18:40:17 -0800
Subject: [PATCH 1/5] add linear warmup and decay scheduler

---
 run_llama_train.sh          |  2 +-
 torchtrain/lr_scheduling.py | 33 +++++++++++++++++++++++++++++++++
 train.py                    |  6 +++++-
 3 files changed, 39 insertions(+), 2 deletions(-)
 create mode 100644 torchtrain/lr_scheduling.py

diff --git a/run_llama_train.sh b/run_llama_train.sh
index 100b52944b..30ff3c471f 100755
--- a/run_llama_train.sh
+++ b/run_llama_train.sh
@@ -5,7 +5,7 @@ set -ex
 TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 
 MODEL="debugmodel"
-NGPU=8
+NGPU=2
 MP=4
 
 torchrun --nproc_per_node=${NGPU} \
diff --git a/torchtrain/lr_scheduling.py b/torchtrain/lr_scheduling.py
new file mode 100644
index 0000000000..986ba78fce
--- /dev/null
+++ b/torchtrain/lr_scheduling.py
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+import contextlib
+import os
+import torch
+from dataclasses import dataclass, field
+from torchtrain.logging_utils import rank0_log
+
+class LinearScheduler:
+    def __init__(self, args):
+        self.lr_max = args.lr
+        self.lr_min = args.lr / 10
+        self.lr_warmup_pct = 0.10
+        # enforce min of 2 steps for warmup
+        self.warmup_steps = max(int(args.steps * self.lr_warmup_pct), 2)
+
+        rank0_log(f"LR Warmup Schedule: {self.lr_min} -> {self.lr_max} with {self.warmup_steps} warmup steps")
+        self.decay_steps = args.steps - self.warmup_steps
+        self.curr_lr = 0
+
+    def set_lr(self, optimizer, step):
+        """ Set the learning rate for the optimizer """
+        if step < self.warmup_steps:
+            self.curr_lr = self.lr_max * (step / self.warmup_steps)
+        else:
+            self.curr_lr = self.lr_min + ((self.lr_max - self.lr_min) * (
+                1 - (step - self.warmup_steps) / self.decay_steps
+            ))
+        # apply across all optim groups
+        for param_group in optimizer.param_groups:
+            param_group["lr"] = self.curr_lr
+        rank0_log(f"Optimizer LR Update: {step=}, lr = {round(self.curr_lr,6)}")
diff --git a/train.py b/train.py
index 073e5b1755..9f684022a0 100644
--- a/train.py
+++ b/train.py
@@ -69,6 +69,9 @@ def main(args):
 
     # build optimizer after apply parallelisms to the model
     # TODO: add scheduler if needed
+    from torchtrain.lr_scheduling import LinearScheduler
+    scheduler = LinearScheduler(args)
+
     optimizer = build_optimizer(model, args)
 
     # TODO: add metrics
@@ -88,6 +91,7 @@ def main(args):
     with maybe_run_profiler() as torch_profiler:
         while train_state.step < args.steps or args.steps == -1:
             train_state.step += 1
+            scheduler.set_lr(optimizer, train_state.step)
             # get batch
             batch = next(iter(data_loader))
             input_ids, labels = batch
@@ -143,7 +147,7 @@ def main(args):
     parser.add_argument(
         "--optimizer", type=str, default="AdamW", help="optimizer to use"
     )
-    parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
+    parser.add_argument("--lr", type=float, default=8e-4, help="learning rate to use")
     parser.add_argument(
         "--steps", type=int, default=-1, help="how many train steps to run"
     )

From 9c23af2f0fa11c52c67becfdc5088ff67f660f6a Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Sun, 28 Jan 2024 18:43:15 -0800
Subject: [PATCH 2/5] ruff check and formatting

---
 run_llama_train.sh          |  2 +-
 torchtrain/lr_scheduling.py | 18 +++++++++---------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/run_llama_train.sh b/run_llama_train.sh
index 30ff3c471f..100b52944b 100755
--- a/run_llama_train.sh
+++ b/run_llama_train.sh
@@ -5,7 +5,7 @@ set -ex
 TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 
 MODEL="debugmodel"
-NGPU=2
+NGPU=8
 MP=4
 
 torchrun --nproc_per_node=${NGPU} \
diff --git a/torchtrain/lr_scheduling.py b/torchtrain/lr_scheduling.py
index 986ba78fce..24d36cdce4 100644
--- a/torchtrain/lr_scheduling.py
+++ b/torchtrain/lr_scheduling.py
@@ -1,12 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 
-import contextlib
-import os
-import torch
-from dataclasses import dataclass, field
 from torchtrain.logging_utils import rank0_log
 
+
 class LinearScheduler:
     def __init__(self, args):
         self.lr_max = args.lr
@@ -15,18 +12,21 @@ def __init__(self, args):
         # enforce min of 2 steps for warmup
         self.warmup_steps = max(int(args.steps * self.lr_warmup_pct), 2)
 
-        rank0_log(f"LR Warmup Schedule: {self.lr_min} -> {self.lr_max} with {self.warmup_steps} warmup steps")
+        rank0_log(
+            f"LR Warmup Schedule: {self.lr_min} -> {self.lr_max} with {self.warmup_steps} warmup steps"
+        )
         self.decay_steps = args.steps - self.warmup_steps
         self.curr_lr = 0
 
     def set_lr(self, optimizer, step):
-        """ Set the learning rate for the optimizer """
+        """Set the learning rate for the optimizer"""
         if step < self.warmup_steps:
             self.curr_lr = self.lr_max * (step / self.warmup_steps)
         else:
-            self.curr_lr = self.lr_min + ((self.lr_max - self.lr_min) * (
-                1 - (step - self.warmup_steps) / self.decay_steps
-            ))
+            self.curr_lr = self.lr_min + (
+                (self.lr_max - self.lr_min)
+                * (1 - (step - self.warmup_steps) / self.decay_steps)
+            )
         # apply across all optim groups
         for param_group in optimizer.param_groups:
             param_group["lr"] = self.curr_lr

From 39cb3483928b3336a7deebca1fd096c376f2d8cb Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Tue, 30 Jan 2024 11:16:47 -0800
Subject: [PATCH 3/5] update lr scheduler to lambdaLR with full linear up and down

---
 torchtrain/lr_scheduling.py | 66 ++++++++++++++++++++-----------------
 train.py                    | 15 +++++----
 2 files changed, 45 insertions(+), 36 deletions(-)

diff --git a/torchtrain/lr_scheduling.py b/torchtrain/lr_scheduling.py
index 24d36cdce4..efc8dbaa49 100644
--- a/torchtrain/lr_scheduling.py
+++ b/torchtrain/lr_scheduling.py
@@ -1,33 +1,39 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 
-from torchtrain.logging_utils import rank0_log
-
-
-class LinearScheduler:
-    def __init__(self, args):
-        self.lr_max = args.lr
-        self.lr_min = args.lr / 10
-        self.lr_warmup_pct = 0.10
-        # enforce min of 2 steps for warmup
-        self.warmup_steps = max(int(args.steps * self.lr_warmup_pct), 2)
-
-        rank0_log(
-            f"LR Warmup Schedule: {self.lr_min} -> {self.lr_max} with {self.warmup_steps} warmup steps"
-        )
-        self.decay_steps = args.steps - self.warmup_steps
-        self.curr_lr = 0
-
-    def set_lr(self, optimizer, step):
-        """Set the learning rate for the optimizer"""
-        if step < self.warmup_steps:
-            self.curr_lr = self.lr_max * (step / self.warmup_steps)
-        else:
-            self.curr_lr = self.lr_min + (
-                (self.lr_max - self.lr_min)
-                * (1 - (step - self.warmup_steps) / self.decay_steps)
-            )
-        # apply across all optim groups
-        for param_group in optimizer.param_groups:
-            param_group["lr"] = self.curr_lr
-        rank0_log(f"Optimizer LR Update: {step=}, lr = {round(self.curr_lr,6)}")
+from torch.optim.lr_scheduler import LambdaLR
+
+# global states for scheduling
+# these are needed as LambdaLR does not support argument passing
+_warmup_steps = 2
+_decay_steps = 0
+
+
+def linear_warmup_linear_decay(current_step: int) -> float:
+    """Computes linear warmup followed by linear decay.
+    Per LambdaLR requirement, this is accomplished by returning
+    a multiplicative factor to adjust the learning rate to
+    create the desired schedule.
+    """
+    if current_step < _warmup_steps:
+        # linear warmup
+        # 0-indexed step, hence + 1 adjustments
+        current_step += 1
+        curr_adjustment = float(current_step / (_warmup_steps + 1))
+
+    else:
+        # linear decay
+        normalized_step = _decay_steps - (current_step - _warmup_steps)
+        curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps
+
+    return curr_adjustment
+
+
+def get_full_lr_scheduler(optimizer, args):
+    """Build a linear warmup and linear decay scheduler"""
+    global _warmup_steps, _decay_steps
+    _warmup_steps = max(int(args.steps * args.warmup_pct), 2)
+    _decay_steps = float(max(1, args.steps - _warmup_steps))
+
+    warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
+    return warmup_scheduler
diff --git a/train.py b/train.py
index 9f684022a0..c5fea71d6c 100644
--- a/train.py
+++ b/train.py
@@ -19,6 +19,7 @@
 )
 from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
 from torchtrain.parallelisms import models_parallelize_fns
+from torchtrain.lr_scheduling import get_full_lr_scheduler
 
 
 @dataclass
@@ -68,11 +69,8 @@ def main(args):
     model = models_parallelize_fns[model_name](model, args)
 
     # build optimizer after apply parallelisms to the model
-    # TODO: add scheduler if needed
-    from torchtrain.lr_scheduling import LinearScheduler
-    scheduler = LinearScheduler(args)
-
     optimizer = build_optimizer(model, args)
+    scheduler = get_full_lr_scheduler(optimizer, args)
 
     # TODO: add metrics
 
@@ -91,7 +89,6 @@ def main(args):
     with maybe_run_profiler() as torch_profiler:
         while train_state.step < args.steps or args.steps == -1:
             train_state.step += 1
-            scheduler.set_lr(optimizer, train_state.step)
             # get batch
             batch = next(iter(data_loader))
             input_ids, labels = batch
@@ -120,7 +117,10 @@ def main(args):
             train_state.current_loss = loss.item()
             train_state.losses.append(train_state.current_loss)
 
-            rank0_log(f"current loss: {train_state.current_loss}")
+            rank0_log(
+                f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
+            )
+            scheduler.step()
 
 
 if __name__ == "__main__":
@@ -148,6 +148,9 @@ def main(args):
         "--optimizer", type=str, default="AdamW", help="optimizer to use"
     )
     parser.add_argument("--lr", type=float, default=8e-4, help="learning rate to use")
+    parser.add_argument(
+        "--warmup_pct", type=float, default=0.10, help="pct training to use for warmup"
+    )
     parser.add_argument(
         "--steps", type=int, default=-1, help="how many train steps to run"
     )

From f42fa47e777ed947629cd0e5fde47df3a811354d Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Thu, 1 Feb 2024 11:40:54 -0800
Subject: [PATCH 4/5] update with pr feedback

---
 torchtrain/lr_scheduling.py | 2 +-
 train.py                    | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/torchtrain/lr_scheduling.py b/torchtrain/lr_scheduling.py
index efc8dbaa49..c477e34293 100644
--- a/torchtrain/lr_scheduling.py
+++ b/torchtrain/lr_scheduling.py
@@ -29,7 +29,7 @@ def linear_warmup_linear_decay(current_step: int) -> float:
     return curr_adjustment
 
 
-def get_full_lr_scheduler(optimizer, args):
+def get_lr_scheduler(optimizer, args):
     """Build a linear warmup and linear decay scheduler"""
     global _warmup_steps, _decay_steps
     _warmup_steps = max(int(args.steps * args.warmup_pct), 2)
diff --git a/train.py b/train.py
index a12a9da421..ff2480b8e4 100644
--- a/train.py
+++ b/train.py
@@ -21,7 +21,7 @@
 )
 from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
 from torchtrain.parallelisms import models_parallelize_fns
-from torchtrain.lr_scheduling import get_full_lr_scheduler
+from torchtrain.lr_scheduling import get_lr_scheduler
 
 
 @dataclass
@@ -87,7 +87,7 @@ def main(args):
 
     # build optimizer after apply parallelisms to the model
     optimizer = build_optimizer(model, args)
-    scheduler = get_full_lr_scheduler(optimizer, args)
+    scheduler = get_lr_scheduler(optimizer, args)
 
     scaler = build_grad_scaler(model)
 
@@ -177,7 +177,7 @@ def main(args):
     )
     parser.add_argument("--lr", type=float, default=8e-4, help="learning rate to use")
     parser.add_argument(
-        "--warmup_pct", type=float, default=0.10, help="pct training to use for warmup"
+        "--warmup_pct", type=float, default=0.10, help="percentage of total training steps to use for warmup"
     )
     parser.add_argument(
         "--max_norm", type=Union[float, int], default=1.0, help="max norm for gradient clipping"
     )
     parser.add_argument(
         "--steps", type=int, default=-1, help="how many train steps to run"
     )
     parser.add_argument(
         "--sp_degree",
         type=int,
-        default=LOCAL_WORLD_SIZE,
+        default=1,
         help="Sequence Parallelism degree. 1 means disabled.",
     )
     parser.add_argument(

From 7f088e707e6792f3552245a2fdb6b982c30f7501 Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Thu, 1 Feb 2024 11:47:11 -0800
Subject: [PATCH 5/5] update with pr feedback, ruff format

---
 train.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index ff2480b8e4..6b186a154f 100644
--- a/train.py
+++ b/train.py
@@ -47,7 +47,7 @@ def build_grad_scaler(model):
     # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
     if model.mixed_precision.param_dtype == torch.float16:
         enable_grad_scaling = True
-        rank0_log(f"Enabling gradient scaling for mixed precision training.")
+        rank0_log("Enabling gradient scaling for mixed precision training.")
     else:
         enable_grad_scaling = False
         rank0_log("Gradient scaling not enabled.")
@@ -177,10 +177,16 @@ def main(args):
     )
     parser.add_argument("--lr", type=float, default=8e-4, help="learning rate to use")
     parser.add_argument(
-        "--warmup_pct", type=float, default=0.10, help="percentage of total training steps to use for warmup"
+        "--warmup_pct",
+        type=float,
+        default=0.10,
+        help="percentage of total training steps to use for warmup",
     )
     parser.add_argument(
-        "--max_norm", type=Union[float, int], default=1.0, help="max norm for gradient clipping"
+        "--max_norm",
+        type=Union[float, int],
+        default=1.0,
+        help="max norm for gradient clipping",
     )
     parser.add_argument(
         "--steps", type=int, default=-1, help="how many train steps to run"
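
For a quick sanity check of the schedule outside of train.py, the following standalone sketch (not part of the patches above) wires the same linear-warmup/linear-decay factor into torch.optim.lr_scheduler.LambdaLR and prints the learning rate it produces. The concrete values steps=20, warmup_pct=0.10, and lr=8e-4, as well as the throwaway parameter/optimizer pair, are assumptions chosen only for illustration.

# Standalone sketch (not part of the patch series): reproduce the LambdaLR
# multiplicative factor used by the scheduler above and print the resulting LR.
# steps=20, warmup_pct=0.10, and lr=8e-4 are assumed values for illustration.
import torch
from torch.optim.lr_scheduler import LambdaLR

steps = 20
warmup_steps = max(int(steps * 0.10), 2)           # mirrors get_lr_scheduler
decay_steps = float(max(1, steps - warmup_steps))


def linear_warmup_linear_decay(current_step: int) -> float:
    # Same shape as the patched function: linear ramp up, then linear decay.
    if current_step < warmup_steps:
        # 0-indexed step, hence the + 1 adjustment
        return float((current_step + 1) / (warmup_steps + 1))
    # algebraically equal to 1 - (current_step - warmup_steps) / decay_steps
    normalized_step = decay_steps - (current_step - warmup_steps)
    return 1 - (decay_steps - normalized_step) / decay_steps


# Throwaway parameter/optimizer pair, only needed to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=8e-4)
scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)

for step in range(steps):
    optimizer.step()   # no-op here (no gradients); keeps the step order valid
    scheduler.step()
    print(f"step {step:2d}: lr = {scheduler.get_last_lr()[0]:.6f}")

With these assumed values warmup_steps resolves to 2, so the multiplier ramps up over the first two steps and then decays linearly to zero by the final step; driving the same lambda with the real args from train.py should reproduce the schedule that the training loop logs via scheduler.get_last_lr().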