diff --git a/library/adafactor_fused.py b/library/optimizer_fused.py
similarity index 54%
rename from library/adafactor_fused.py
rename to library/optimizer_fused.py
index bdfc32ced..83fe9203e 100644
--- a/library/adafactor_fused.py
+++ b/library/optimizer_fused.py
@@ -1,6 +1,6 @@
 import math
 import torch
-from transformers import Adafactor
+from transformers import Adafactor, AdamW
 
 @torch.no_grad()
 def adafactor_step_param(self, p, group):
@@ -81,9 +81,57 @@ def adafactor_step_param(self, p, group):
     if p.dtype in {torch.float16, torch.bfloat16}:
         p.copy_(p_data_fp32)
 
 
+@torch.no_grad()
+def adamw_step_param(self, p, group):
+    if p.grad is None:
+        return
+    grad = p.grad
+    if grad.is_sparse:
+        raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
+
+    state = self.state[p]
+
+    # State initialization
+    if len(state) == 0:
+        state["step"] = 0
+        # Exponential moving average of gradient values
+        state["exp_avg"] = torch.zeros_like(p)
+        # Exponential moving average of squared gradient values
+        state["exp_avg_sq"] = torch.zeros_like(p)
+
+    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+    beta1, beta2 = group["betas"]
+
+    state["step"] += 1
+
+    # Decay the first and second moment running average coefficient
+    # In-place operations to update the averages at the same time
+    exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
+    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+    denom = exp_avg_sq.sqrt().add_(group["eps"])
+
+    step_size = group["lr"]
+    # if group["correct_bias"]:  # No bias correction for Bert
+    #     bias_correction1 = 1.0 - beta1 ** state["step"]
+    #     bias_correction2 = 1.0 - beta2 ** state["step"]
+    #     step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
+
+    p.addcdiv_(exp_avg, denom, value=-step_size)
+
+    # Just adding the square of the weights to the loss function is *not*
+    # the correct way of using L2 regularization/weight decay with Adam,
+    # since that will interact with the m and v parameters in strange ways.
+    #
+    # Instead we want to decay the weights in a manner that doesn't interact
+    # with the m/v parameters. This is equivalent to adding the square
+    # of the weights to the loss with plain (non-momentum) SGD.
+    # Add weight decay at the end (fixed version)
+    if group["weight_decay"] > 0.0:
+        p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
+
 @torch.no_grad()
-def adafactor_step(self, closure=None):
+def optimizer_step(self, closure=None):
     """
     Performs a single optimization step
 
@@ -97,10 +145,21 @@ def adafactor_step(self, closure=None):
 
     for group in self.param_groups:
         for p in group["params"]:
-            adafactor_step_param(self, p, group)
+            self.step_param(p, group)
 
     return loss
 
-def patch_adafactor_fused(optimizer: Adafactor):
-    optimizer.step_param = adafactor_step_param.__get__(optimizer)
-    optimizer.step = adafactor_step.__get__(optimizer)
+def patch_optimizer_fused(optimizer, optimizer_type: str):
+    # Select the per-parameter step function for the requested optimizer type.
+    if optimizer_type.lower() == "adamw":
+        print("Using AdamW Fused")
+        step_param = adamw_step_param
+    elif optimizer_type.lower() == "adafactor":
+        print("Using Adafactor Fused")
+        step_param = adafactor_step_param
+    else:
+        raise ValueError(f"fused backward pass is not supported for optimizer_type: {optimizer_type}")
+    # Bind as methods: gradient hooks call optimizer.step_param(p, group), and
+    # optimizer.step() still performs a full pass over all parameters via step_param.
+    optimizer.step_param = step_param.__get__(optimizer)
+    optimizer.step = optimizer_step.__get__(optimizer)
diff --git a/library/train_util.py b/library/train_util.py
index 46b55c03e..74f950464 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3852,9 +3852,10 @@ def get_optimizer(args, trainable_params):
     optimizer_type = optimizer_type.lower()
 
     if args.fused_backward_pass:
+        accepted_optimizers = ["Adafactor", "AdamW"]
         assert (
-            optimizer_type == "Adafactor".lower()
-        ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
+            optimizer_type in [accepted.lower() for accepted in accepted_optimizers]
+        ), f"fused_backward_pass currently only works with optimizer_type in {accepted_optimizers} / fused_backward_passは現在optimizer_type {accepted_optimizers}でのみ機能します"
         assert (
             args.gradient_accumulation_steps == 1
         ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
diff --git a/sdxl_train.py b/sdxl_train.py
index 3b28575ed..fd4ce3363 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -431,8 +431,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     if args.fused_backward_pass:
-        import library.adafactor_fused
-        library.adafactor_fused.patch_adafactor_fused(optimizer)
+        import library.optimizer_fused
+        library.optimizer_fused.patch_optimizer_fused(optimizer, args.optimizer_type)
         for param_group in optimizer.param_groups:
             for parameter in param_group["params"]:
                 if parameter.requires_grad:
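
The patched optimizer only supplies step_param / step; the trainer is expected to register a per-parameter gradient hook that calls optimizer.step_param as soon as each gradient has been accumulated, which is what the sdxl_train.py context above leads into. Below is a minimal, self-contained sketch of that wiring, assuming PyTorch >= 2.1 (for Tensor.register_post_accumulate_grad_hook) and using torch.optim.AdamW plus a toy model as stand-ins for the objects built by the training script; the hook body is illustrative, not the exact sdxl_train.py code.

import torch
from library import optimizer_fused

# Toy stand-ins for the model and optimizer built by the training script.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

# Install step_param / step on the optimizer instance.
optimizer_fused.patch_optimizer_fused(optimizer, "AdamW")

for param_group in optimizer.param_groups:
    for parameter in param_group["params"]:
        if parameter.requires_grad:

            def grad_hook(tensor: torch.Tensor, param_group=param_group):
                # Update this parameter as soon as its gradient is ready,
                # then free the gradient instead of keeping it for a global step().
                optimizer.step_param(tensor, param_group)
                tensor.grad = None

            parameter.register_post_accumulate_grad_hook(grad_hook)

# backward() now applies the optimizer update parameter by parameter;
# no explicit optimizer.step() call is needed afterwards.
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()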