diff --git a/README.md b/README.md
index b1b924ec2..d9b2ddc44 100644
--- a/README.md
+++ b/README.md
@@ -133,6 +133,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 ## Change History
 
+### Masked loss
+
+`train_network.py`, `sdxl_train_network.py` and `sdxl_train.py` now support masked loss via the new `--masked_loss` option.
+
+NOTE: `train_network.py` and `sdxl_train.py` are not tested yet.
+
+A ControlNet dataset is used to specify the mask. The mask images should be RGB images. A pixel value of 255 in the R channel is treated as masked (the loss is calculated only for the masked pixels), and 0 is treated as unmasked. See the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset) for details of the dataset specification.
+
+
 ### Working in progress
 
 - Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging.
diff --git a/docs/train_lllite_README-ja.md b/docs/train_lllite_README-ja.md
index dbdc1fea2..1f6a78d5c 100644
--- a/docs/train_lllite_README-ja.md
+++ b/docs/train_lllite_README-ja.md
@@ -21,9 +21,13 @@ A custom node for ComfyUI is available: https://github.com/k
 
 ## Training the model
 
 ### Preparing the dataset
-In addition to the normal dataset, store the conditioning images in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. It is automatically resized to the same size as the training image. Conditioning images do not require caption files.
+Use a DreamBooth-style dataset, and store the conditioning images in the directory specified by `conditioning_data_dir`.
 
-For example, the configuration file when using the DreamBooth method with caption files looks like this:
+(Finetuning-style datasets are not supported.)
+
+The conditioning image must have the same basename as the training image, and it is automatically resized to the same size as the training image. Conditioning images do not require caption files.
+
+For example, the configuration file when using caption files rather than folder names for captions looks like this:
 
 ```toml
 [[datasets.subsets]]
diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md
index 04dc12da2..a05f87f5f 100644
--- a/docs/train_lllite_README.md
+++ b/docs/train_lllite_README.md
@@ -26,7 +26,9 @@ Due to the limitations of the inference environment, only CrossAttention (attn1
 
 ### Preparing the dataset
 
-In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
+In addition to the normal DreamBooth-method dataset, please store the conditioning images in the directory specified by `conditioning_data_dir`. Each conditioning image must have the same basename as its training image and will be automatically resized to the same size as the training image. Conditioning images do not require caption files.
+
+(The finetuning-method dataset is not supported.)
 
 ```toml
 [[datasets.subsets]]
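For reference, a mask matching the format described above (RGB, R channel 255 = trained region, same basename as the training image) can be generated with a few lines of Python. This is a minimal sketch, assuming PIL and numpy are installed; the file paths and the masked rectangle are hypothetical:

```python
# Minimal sketch: create a mask image for the masked-loss dataset.
# Hypothetical paths; the mask must share its basename with the training image
# and live in the subset's conditioning_data_dir.
import numpy as np
from PIL import Image

train_image = Image.open("train_images/sample.png")
mask = np.zeros((train_image.height, train_image.width, 3), dtype=np.uint8)
mask[64:320, 64:448, 0] = 255  # R=255: loss is computed here; R=0 elsewhere is ignored
Image.fromarray(mask, mode="RGB").save("conditioning_images/sample.png")
```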
diff --git a/library/config_util.py b/library/config_util.py
index ff4de0921..26daeb472 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -253,9 +253,10 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
     }
 
     def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
-        assert (
-            support_dreambooth or support_finetuning or support_controlnet
-        ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
+        assert support_dreambooth or support_finetuning or support_controlnet, (
+            "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
+            + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
+        )
 
         self.db_subset_schema = self.__merge_dict(
             self.SUBSET_ASCENDABLE_SCHEMA,
@@ -322,7 +323,10 @@ def validate_flex_dataset(dataset_config: dict):
             self.dataset_schema = validate_flex_dataset
 
         elif support_dreambooth:
-            self.dataset_schema = self.db_dataset_schema
+            if support_controlnet:
+                self.dataset_schema = self.cn_dataset_schema
+            else:
+                self.dataset_schema = self.db_dataset_schema
         elif support_finetuning:
             self.dataset_schema = self.ft_dataset_schema
         elif support_controlnet:
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index a56474622..406e0e36e 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -3,11 +3,14 @@ import random
 import re
 from typing import List, Optional, Union
-from .utils import setup_logging
+from .utils import setup_logging
+
 setup_logging()
-import logging
+import logging
+
 logger = logging.getLogger(__name__)
+
 def prepare_scheduler_for_custom_training(noise_scheduler, device):
     if hasattr(noise_scheduler, "all_snr"):
         return
@@ -64,7 +67,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
     snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
     min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
     if v_prediction:
-        snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
+        snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
     else:
         snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
     loss = loss * snr_weight
@@ -92,13 +95,15 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
         loss = loss + loss / scale * v_pred_like_loss
     return loss
 
+
 def apply_debiased_estimation(loss, timesteps, noise_scheduler):
     snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
     snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
-    weight = 1/torch.sqrt(snr_t)
+    weight = 1 / torch.sqrt(snr_t)
     loss = weight * loss
     return loss
 
+
 # TODO this is duplicated in train_util; consolidate into one of the two
@@ -474,6 +479,17 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
     return noise
 
 
+def apply_masked_loss(loss, batch):
+    # mask image is -1 to 1. we need to convert it to 0 to 1
+    mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
+
+    # resize to the same size as the loss
+    mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
+    mask_image = mask_image / 2 + 0.5
+    loss = loss * mask_image
+    return loss
+
+
 """
 ##########################################
 # Perlin Noise
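The rescaling inside `apply_masked_loss` follows from the dataset transforms normalizing images to [-1, 1]: `x / 2 + 0.5` maps the R channel back to [0, 1], and `interpolate(..., mode="area")` downsamples the pixel-space mask to the latent-space loss resolution. The following is a self-contained sketch of the same arithmetic on toy tensors; the shapes are illustrative, not taken from the trainers:

```python
import torch
from torch.nn.functional import interpolate

# Toy stand-ins: per-pixel loss over 64x64 latents, 512x512 RGB conditioning images in [-1, 1]
loss = torch.rand(2, 4, 64, 64)
cond = torch.full((2, 3, 512, 512), -1.0)  # all black: R channel -1 -> weight 0
cond[0, 0, :, 256:] = 1.0                  # sample 0: right half masked (R=255 -> weight 1)

mask = cond[:, 0].unsqueeze(1)                              # keep only the R channel
mask = interpolate(mask, size=loss.shape[2:], mode="area")  # 512x512 -> 64x64
mask = mask / 2 + 0.5                                       # [-1, 1] -> [0, 1]
masked = loss * mask
print(masked[0, :, :, :32].abs().max())  # ~0: the unmasked left half contributes nothing
print(masked[1].abs().max())             # ~0: a fully unmasked sample is zeroed out
```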
diff --git a/library/train_util.py b/library/train_util.py
index 6ca5d559e..73cf30a9f 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1835,6 +1835,9 @@ def __init__(
 
         db_subsets = []
         for subset in subsets:
+            assert (
+                not subset.random_crop
+            ), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません"
             db_subset = DreamBoothSubset(
                 subset.image_dir,
                 False,
@@ -1885,7 +1888,7 @@ def __init__(
 
         # assert all conditioning data exists
         missing_imgs = []
-        cond_imgs_with_img = set()
+        cond_imgs_with_pair = set()
         for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
             db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
             subset = None
@@ -1899,23 +1902,29 @@ def __init__(
                 logger.warning(f"not directory: {subset.conditioning_data_dir}")
                 continue
 
-            img_basename = os.path.basename(info.absolute_path)
-            ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
-            if not os.path.exists(ctrl_img_path):
+            img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0]
+            ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename)
+            if len(ctrl_img_path) < 1:
                 missing_imgs.append(img_basename)
+                continue
+            ctrl_img_path = ctrl_img_path[0]
+            ctrl_img_path = os.path.abspath(ctrl_img_path)  # normalize path
 
             info.cond_img_path = ctrl_img_path
-            cond_imgs_with_img.add(ctrl_img_path)
+            cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0])  # remove extension because Windows is case insensitive
 
         extra_imgs = []
         for subset in subsets:
             conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
-            extra_imgs.extend(
-                [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
-            )
+            conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths]  # normalize path
+            extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
 
-        assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
-        assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
+        assert (
+            len(missing_imgs) == 0
+        ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
+        assert (
+            len(extra_imgs) == 0
+        ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
 
         self.conditioning_image_transforms = IMAGE_TRANSFORMS
@@ -3049,6 +3058,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     )  # TODO move to SDXL training, because it is not supported by SD1/2
     parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
+
     parser.add_argument(
         "--ddp_timeout",
         type=int,
@@ -3111,6 +3121,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
     )
+
     parser.add_argument(
         "--noise_offset",
         type=float,
@@ -3284,6 +3295,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     )
 
 
+def add_masked_loss_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--conditioning_data_dir",
+        type=str,
+        default=None,
+        help="conditioning data directory / 条件付けデータのディレクトリ",
+    )
+    parser.add_argument(
+        "--masked_loss",
+        action="store_true",
+        help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
+    )
+
+
 def verify_training_args(args: argparse.Namespace):
     r"""
     Verify training arguments. Also reflect highvram option to global variable
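The pairing check above now compares extension-stripped absolute paths, so a training image `001.jpg` can pair with a mask `001.png`, and every conditioning image must correspond to exactly one training image. A small sketch of the same set-based check, independent of the dataset classes (the paths are hypothetical):

```python
import os

train_paths = ["/data/train/001.jpg", "/data/train/002.jpg"]
cond_paths = ["/data/cond/001.png", "/data/cond/003.png"]  # 002 missing, 003 extra

def stem(path: str) -> str:
    # extension-stripped basename, mirroring the os.path.splitext comparison above
    return os.path.splitext(os.path.basename(path))[0]

train_stems = {stem(p) for p in train_paths}
cond_stems = {stem(p) for p in cond_paths}
missing = sorted(train_stems - cond_stems)  # ['002'] -> would trip the first assert
extra = sorted(cond_stems - train_stems)    # ['003'] -> would trip the second assert
print(missing, extra)
```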
diff --git a/sdxl_train.py b/sdxl_train.py
index 2cb80b6b3..816598e04 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -12,6 +12,7 @@
 import torch
 
 from library.device_utils import init_ipex, clean_memory_on_device
+
 init_ipex()
 
 from accelerate.utils import set_seed
@@ -40,6 +41,7 @@
     scale_v_prediction_loss_like_noise_prediction,
     add_v_prediction_like_loss,
     apply_debiased_estimation,
+    apply_masked_loss,
 )
 from library.sdxl_original_unet import SdxlUNet2DConditionModel
 
@@ -126,7 +128,7 @@ def train(args):
     # prepare the dataset
     if args.dataset_class is None:
-        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
         if args.dataset_config is not None:
             logger.info(f"Load dataset config from {args.dataset_config}")
             user_config = config_util.load_user_config(args.dataset_config)
@@ -595,9 +597,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                     or args.scale_v_pred_loss_like_noise_pred
                     or args.v_pred_like_loss
                     or args.debiased_estimation_loss
+                    or args.masked_loss
                 ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                    if args.masked_loss:
+                        loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
 
                     if args.min_snr_gamma:
@@ -763,6 +768,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, True, True, True)
     train_util.add_training_arguments(parser, False)
+    train_util.add_masked_loss_arguments(parser)
     deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_sd_saving_arguments(parser)
     train_util.add_optimizer_arguments(parser)
@@ -799,7 +805,6 @@ def setup_parser() -> argparse.ArgumentParser:
         help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
         + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
     )
-
     return parser
diff --git a/train_db.py b/train_db.py
index bc7ef4ed5..0a152f224 100644
--- a/train_db.py
+++ b/train_db.py
@@ -14,6 +14,7 @@
 from library import deepspeed_utils
 from library.device_utils import init_ipex, clean_memory_on_device
+
 init_ipex()
 
 from accelerate.utils import set_seed
@@ -34,6 +35,7 @@
     apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
     apply_debiased_estimation,
+    apply_masked_loss,
 )
 from library.utils import setup_logging, add_logging_arguments
 
@@ -60,7 +62,7 @@ def train(args):
     # prepare the dataset
     if args.dataset_class is None:
-        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True))
         if args.dataset_config is not None:
             logger.info(f"Load dataset config from {args.dataset_config}")
             user_config = config_util.load_user_config(args.dataset_config)
@@ -357,6 +359,8 @@ def train(args):
                     target = noise
 
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                if args.masked_loss:
+                    loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # weight for each sample
@@ -482,6 +486,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, True, False, True)
     train_util.add_training_arguments(parser, True)
+    train_util.add_masked_loss_arguments(parser)
     deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_sd_saving_arguments(parser)
     train_util.add_optimizer_arguments(parser)
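The ordering in these training-loop hunks matters: the MSE is kept unreduced (`reduction="none"`), the spatial mask is applied per pixel, and only then is the loss averaged over channel and spatial dimensions, leaving a per-sample vector that `loss_weights` can scale. A toy sketch of that reduction pipeline, with stand-in tensors:

```python
import torch

noise_pred = torch.randn(4, 4, 64, 64)
target = torch.randn(4, 4, 64, 64)
mask = torch.rand(4, 1, 64, 64)                    # stand-in for the resized [0, 1] mask
loss_weights = torch.tensor([1.0, 1.0, 2.0, 0.5])  # per-sample weights from the batch

loss = torch.nn.functional.mse_loss(noise_pred, target, reduction="none")  # (4, 4, 64, 64)
loss = loss * mask                   # zero out unmasked pixels
loss = loss.mean([1, 2, 3])          # reduce to one value per sample: (4,)
loss = (loss * loss_weights).mean()  # weighted mean over the batch
print(loss.shape, float(loss))
```

Note that, as written in the diff, `mean([1, 2, 3])` still divides by all pixels rather than only the masked ones, so a smaller mask also scales the effective loss magnitude down.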
diff --git a/train_network.py b/train_network.py
index f2451f412..baf58ad58 100644
--- a/train_network.py
+++ b/train_network.py
@@ -14,6 +14,7 @@
 import torch
 
 from library.device_utils import init_ipex, clean_memory_on_device
+
 init_ipex()
 
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -40,6 +41,7 @@
     scale_v_prediction_loss_like_noise_prediction,
     add_v_prediction_like_loss,
     apply_debiased_estimation,
+    apply_masked_loss,
 )
 from library.utils import setup_logging, add_logging_arguments
 
@@ -159,7 +161,7 @@ def train(self, args):
         # prepare the dataset
         if args.dataset_class is None:
-            blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
+            blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
             if use_user_config:
                 logger.info(f"Loading dataset config from {args.dataset_config}")
                 user_config = config_util.load_user_config(args.dataset_config)
@@ -852,6 +854,8 @@ def remove_model(old_ckpt_name):
                     target = noise
 
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                if args.masked_loss:
+                    loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # weight for each sample
@@ -975,6 +979,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, True, True, True)
     train_util.add_training_arguments(parser, True)
+    train_util.add_masked_loss_arguments(parser)
    deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 72e26a450..e7083596f 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -9,6 +9,7 @@
 import torch
 
 from library.device_utils import init_ipex, clean_memory_on_device
+
 init_ipex()
 
 from accelerate.utils import set_seed
@@ -30,6 +31,7 @@
     scale_v_prediction_loss_like_noise_prediction,
     add_v_prediction_like_loss,
     apply_debiased_estimation,
+    apply_masked_loss,
 )
 from library.utils import setup_logging, add_logging_arguments
 
@@ -269,7 +271,7 @@ def train(self, args):
         # prepare the dataset
         if args.dataset_class is None:
-            blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
+            blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
             if args.dataset_config is not None:
                 accelerator.print(f"Load dataset config from {args.dataset_config}")
                 user_config = config_util.load_user_config(args.dataset_config)
@@ -587,6 +589,8 @@ def remove_model(old_ckpt_name):
                     target = noise
 
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                if args.masked_loss:
+                    loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # weight for each sample
@@ -750,6 +754,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, True, True, False)
     train_util.add_training_arguments(parser, True)
+    train_util.add_masked_loss_arguments(parser)
     deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -750,6 +754,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 00b251ba9..861d48d1d 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -10,6 +10,7 @@ import torch from library import deepspeed_utils from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate.utils import set_seed @@ -32,6 +33,7 @@ apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, + apply_masked_loss, ) import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI @@ -201,7 +203,7 @@ def train(args): logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False)) if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -472,6 +474,8 @@ def remove_model(old_ckpt_name): target = noise loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -663,6 +667,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser)