diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index 2a513dc5b..6a054e445 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -96,10 +96,13 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
     return loss
 
 
-def apply_debiased_estimation(loss, timesteps, noise_scheduler):
+def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
     snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
     snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
-    weight = 1 / torch.sqrt(snr_t)
+    if v_prediction:
+        weight = 1 / (snr_t + 1)
+    else:
+        weight = 1 / torch.sqrt(snr_t)
     loss = weight * loss
     return loss
 
diff --git a/sdxl_train.py b/sdxl_train.py
index e0a8f2b2e..b533b2749 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -730,7 +730,7 @@ def optimizer_hook(parameter: torch.Tensor):
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
 
                 loss = loss.mean()  # mean over batch dimension
             else:
diff --git a/train_network.py b/train_network.py
index 044ec3aa8..7bf125dca 100644
--- a/train_network.py
+++ b/train_network.py
@@ -998,7 +998,7 @@ def remove_model(old_ckpt_name):
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
 
                 loss = loss.mean()  # mean, so no need to divide by batch_size
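
For reference, a minimal standalone sketch of the weighting this patch introduces. The helper name debiased_estimation_weight and the sample SNR values below are illustrative only; the patched function takes per-timestep SNR from noise_scheduler.all_snr and multiplies the per-sample loss by the resulting weight.

import torch

def debiased_estimation_weight(snr_t: torch.Tensor, v_prediction: bool = False) -> torch.Tensor:
    # Clamp SNR, since at timestep 0 it is infinite.
    snr_t = torch.minimum(snr_t, torch.full_like(snr_t, 1000.0))
    if v_prediction:
        return 1 / (snr_t + 1)        # v-prediction branch added by this patch
    return 1 / torch.sqrt(snr_t)      # original epsilon-prediction weighting

# Compare the two weightings over a range of SNR values (illustrative inputs).
snr = torch.tensor([0.01, 0.1, 1.0, 10.0, 100.0])
print(debiased_estimation_weight(snr, v_prediction=False))
print(debiased_estimation_weight(snr, v_prediction=True))

Callers select the branch via args.v_parameterization, so v-prediction models get the 1 / (snr_t + 1) weight while epsilon-prediction models keep the previous 1 / sqrt(snr_t) behavior.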