
Commit c1a61bb

bghira authored

Import SimpleTuner prompt validation library via --validation_prompt_library (#29)

* README for the datasets directory
* Option to use the SimpleTuner prompt library for validations (#28)
* Validations: optionally use the SimpleTuner prompt library to validate images
* EMA: allow the option to be removed by setting ema_unet to None for SDXLSaveHook
* Export: use safetensors
* Validations fix: do not perform validation multiple times on a gradient step

Co-authored-by: bghira <[email protected]>

1 parent: c550df5

File tree

7 files changed: +95 -57 lines changed

helpers/arguments.py

Lines changed: 5 additions & 0 deletions

@@ -414,6 +414,11 @@ def parse_args(input_args=None):
         default=None,
         help="A prompt that is used during validation to verify that the model is learning.",
     )
+    parser.add_argument(
+        "--validation_prompt_library",
+        action="store_true",
+        help="If this is provided, the SimpleTuner prompt library will be used to generate multiple images.",
+    )
     parser.add_argument(
         "--num_validation_images",
         type=int,
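
Since the flag is declared with action="store_true", it takes no value: absent it parses to False, present it parses to True. A minimal standalone sketch of that behavior (not the project's parser, just the same argparse pattern):

    import argparse

    # Standalone parser mirroring the new flag's declaration.
    parser = argparse.ArgumentParser()
    parser.add_argument("--validation_prompt_library", action="store_true")

    assert parser.parse_args([]).validation_prompt_library is False
    assert parser.parse_args(["--validation_prompt_library"]).validation_prompt_library is True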
prompts.py → helpers/prompts.py

File renamed without changes.

inference.py

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 from accelerate import Accelerator
 from diffusers import DiffusionPipeline, UNet2DConditionModel, DDPMScheduler, DDIMScheduler
 from transformers import CLIPTextModel
-from prompts import prompts
+from helpers.prompts import prompts
 from compel import Compel
 
 import torch, os, logging
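
All three one-line import fixes in this commit track the same move of the prompt module into the helpers package. A hedged sanity check of the new path, assuming the repository root is on sys.path and the dict shape documented in the train_sdxl.py hunk further below:

    # Assumption: prompts is { 'shortname': 'prompt text', ... }, per the
    # comment in train_sdxl.py further down this commit.
    from helpers.prompts import prompts

    for shortname, prompt in prompts.items():
        print(f"{shortname}: {prompt}")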

toolkit/datasets/README.md

Whitespace-only changes.

toolkit/inference/inference_snr_test.py

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 from PIL import Image
 from diffusers import StableDiffusionPipeline, DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDPMScheduler, DDIMScheduler
 from transformers import CLIPTextModel
-from prompts import prompts
+from helpers.prompts import prompts
 model_id = '/notebooks/datasets/models/pseudo-realism'
 #model_id = 'stabilityai/stable-diffusion-2-1'
 pipe = StableDiffusionPipeline.from_pretrained(model_id)

toolkit/inference/tile_shortnames.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 import os
 from PIL import Image, ImageDraw, ImageFont
-from prompts import prompts
+from helpers.prompts import prompts
 grid_dir = '/notebooks/SimpleTuner/grid'
 output_dir = '/notebooks/datasets/test_results'
 
train_sdxl.py

Lines changed: 87 additions & 54 deletions
@@ -210,6 +210,7 @@ def main():
     )
 
     # Create EMA for the unet.
+    ema_unet = None
     if args.use_ema:
         logger.info("Using EMA. Creating EMAModel.")
         ema_unet = EMAModel(
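
Defaulting ema_unet to None gives downstream code a sentinel to test, which is what the commit message's "EMA: allow the option to be removed by setting ema_unet to None for SDXLSaveHook" refers to. A minimal sketch of the pattern; the hook name and signature here are hypothetical, not the repo's actual SDXLSaveHook:

    # Hypothetical hook: skips EMA handling entirely when --use_ema was not set.
    def save_hook(unet, ema_unet=None):
        if ema_unet is not None:
            # diffusers' EMAModel.copy_to transfers the averaged weights.
            ema_unet.copy_to(unet.parameters())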
@@ -571,11 +572,28 @@ def collate_fn(examples):
     logger.info(f"Pre-computing text embeds / updating cache.")
     embed_cache.precompute_embeddings_for_prompts(train_dataset.get_all_captions())
 
-    if args.validation_prompt is not None:
+    validation_prompts = []
+    validation_shortnames = []
+    if args.validation_prompt_library:
+        # Use the SimpleTuner prompts library for validation prompts.
+        from helpers.prompts import prompts as prompt_library
+        # Prompt format: { 'shortname': 'this is the prompt', ... }
+        for shortname, prompt in prompt_library.items():
+            logger.info(f'Precomputing validation prompt embeds: {shortname}')
+            embed_cache.compute_embeddings_for_prompts([prompt])
+            validation_prompts.append(prompt)
+            validation_shortnames.append(shortname)
+    elif args.validation_prompt is not None:
+        # Use a single prompt for validation.
+        validation_prompts = [args.validation_prompt]
+        validation_shortnames = ['validation']
         (
             validation_prompt_embeds,
             validation_pooled_embeds,
         ) = embed_cache.compute_embeddings_for_prompts([args.validation_prompt])
+
+    # Compute negative embed for validation prompts, if any are set.
+    if validation_prompts:
         (
             validation_negative_prompt_embeds,
             validation_negative_pooled_embeds,
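
The branch above fills two parallel lists that every validation pass below consumes, either from the whole library or from the single --validation_prompt. A toy illustration with a made-up two-entry library in the documented format:

    # Hypothetical miniature prompt library: { 'shortname': 'prompt', ... }
    prompt_library = {
        "portrait": "a photograph of a woman",
        "skyline": "a city skyline at dusk",
    }
    validation_prompts = []
    validation_shortnames = []
    for shortname, prompt in prompt_library.items():
        validation_prompts.append(prompt)
        validation_shortnames.append(shortname)
    # validation_prompts    -> ["a photograph of a woman", "a city skyline at dusk"]
    # validation_shortnames -> ["portrait", "skyline"]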
@@ -885,8 +903,7 @@ def collate_fn(examples):
 
         ### BEGIN: Perform validation every `validation_epochs` steps
         if accelerator.is_main_process:
-            if global_step % args.validation_steps == 0 and global_step > 1:
-                pass
+            if validation_prompts and global_step % args.validation_steps == 0 and global_step > 1:
                 if (
                     args.validation_prompt is None
                     or args.num_validation_images is None
@@ -896,7 +913,9 @@ def collate_fn(examples):
                         f"Not generating any validation images for this checkpoint. Live dangerously and prosper, pal!"
                     )
                     continue
-
+                if args.gradient_accumulation_steps > 0 and step % args.gradient_accumulation_steps != 0:
+                    # We do not want to perform validation on a partial batch.
+                    continue
                 logger.info(
                     f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                     f" {args.validation_prompt}."
@@ -919,7 +938,7 @@ def collate_fn(examples):
                     revision=args.revision,
                     torch_dtype=weight_dtype,
                 )
-                pipeline.scheduler.config.prediction_type = "v_prediction"
+                pipeline.scheduler.config.prediction_type = args.prediction_type or noise_scheduler.config.prediction_type
                 pipeline = pipeline.to(accelerator.device)
                 pipeline.set_progress_bar_config(disable=True)
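
The replaced line stops hard-coding "v_prediction" and instead falls through Python's `or`: an unset --prediction_type (None) defers to the training scheduler's own config. Illustrated with hypothetical values:

    # Hypothetical values: an explicit --prediction_type would win; None falls back.
    args_prediction_type = None            # flag not supplied
    scheduler_prediction_type = "epsilon"  # e.g. from the noise scheduler's config
    effective = args_prediction_type or scheduler_prediction_type
    assert effective == "epsilon"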

@@ -936,22 +955,39 @@ def collate_fn(examples):
                         or accelerator.mixed_precision == "bf16"
                     ),
                 ):
-                    validation_generator = torch.Generator(
-                        device=accelerator.device
-                    ).manual_seed(args.seed or 0)
-                    validation_images = pipeline(
-                        prompt_embeds=validation_prompt_embeds,
-                        pooled_prompt_embeds=validation_pooled_embeds,
-                        negative_prompt_embeds=validation_negative_prompt_embeds,
-                        negative_pooled_prompt_embeds=validation_negative_pooled_embeds,
-                        num_images_per_prompt=args.num_validation_images,
-                        num_inference_steps=30,
-                        guidance_scale=args.validation_guidance,
-                        guidance_rescale=args.validation_guidance_rescale,
-                        generator=validation_generator,
-                        height=args.validation_resolution,
-                        width=args.validation_resolution,
-                    ).images
+                    validation_images = []
+                    pipeline = pipeline.to(accelerator.device)
+                    with torch.autocast(str(accelerator.device).replace(":0", "")):
+                        validation_generator = torch.Generator(
+                            device=accelerator.device
+                        ).manual_seed(args.seed or 0)
+                        for validation_prompt in validation_prompts:
+                            # Each validation prompt needs its own embed.
+                            current_validation_prompt_embeds, current_validation_pooled_embeds = embed_cache.compute_embeddings_for_prompts(
+                                [validation_prompt]
+                            )
+                            logger.info(f'Generating validation image: {validation_prompt}')
+                            validation_images.extend(pipeline(
+                                prompt_embeds=current_validation_prompt_embeds,
+                                pooled_prompt_embeds=current_validation_pooled_embeds,
+                                negative_prompt_embeds=validation_negative_prompt_embeds,
+                                negative_pooled_prompt_embeds=validation_negative_pooled_embeds,
+                                num_images_per_prompt=args.num_validation_images,
+                                num_inference_steps=30,
+                                guidance_scale=args.validation_guidance,
+                                guidance_rescale=args.validation_guidance_rescale,
+                                generator=validation_generator,
+                                height=args.validation_resolution,
+                                width=args.validation_resolution,
+                            ).images)
+
+                    for tracker in accelerator.trackers:
+                        if tracker.name == "wandb":
+                            validation_document = {}
+                            for idx, validation_image in enumerate(validation_images):
+                                # Create a WandB entry containing each image.
+                                validation_document[validation_shortnames[idx]] = wandb.Image(validation_image)
+                            tracker.log(validation_document, step=global_step)
                 val_img_idx = 0
                 for a_val_img in validation_images:
                     a_val_img.save(
@@ -962,14 +998,6 @@ def collate_fn(examples):
                     )
                     val_img_idx += 1
 
-                for tracker in accelerator.trackers:
-                    if tracker.name == "wandb":
-                        idx = 0
-                        for validation_image in validation_images:
-                            tracker.log(
-                                {f"image-{idx}": wandb.Image(validation_images[idx])}
-                            )
-                            idx += 1
                 if args.use_ema:
                     # Switch back to the original UNet parameters.
                     ema_unet.restore(unet.parameters())
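
Net effect of the two wandb hunks above: the removed loop logged anonymous image-0, image-1, ... keys on separate calls, while the new code builds one dict keyed by prompt shortname and logs it once with the global step attached. A toy equivalent of the new pattern, with string stand-ins for the wandb.Image-wrapped PIL images:

    # Toy stand-ins; the real code wraps each PIL image in wandb.Image(...).
    validation_shortnames = ["portrait", "skyline"]
    validation_images = ["<portrait image>", "<skyline image>"]
    validation_document = {
        name: image for name, image in zip(validation_shortnames, validation_images)
    }
    # -> {"portrait": "<portrait image>", "skyline": "<skyline image>"}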
@@ -998,7 +1026,8 @@ def collate_fn(examples):
             unet=unet,
             revision=args.revision,
         )
-        pipeline.save_pretrained("/notebooks/datasets/models/ptx0-xltest")
+        pipeline.scheduler.config = noise_scheduler.config
+        pipeline.save_pretrained("/notebooks/datasets/models/ptx0-xltest", safe_serialization=True)
 
         if args.push_to_hub:
             upload_folder(
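
Exporting with safe_serialization=True writes the weights as .safetensors files (the "Export: use safetensors" item in the commit message), and the scheduler config is copied from the training noise_scheduler so reloads match training. A hedged round-trip sketch, reusing the path hard-coded above; use_safetensors is a standard diffusers from_pretrained option, availability depending on the installed version:

    from diffusers import DiffusionPipeline

    # Reload the exported pipeline, preferring the .safetensors weight files.
    pipeline = DiffusionPipeline.from_pretrained(
        "/notebooks/datasets/models/ptx0-xltest",
        use_safetensors=True,
    )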
@@ -1008,36 +1037,40 @@ def collate_fn(examples):
                 ignore_patterns=["step_*", "epoch_*"],
             )
 
-        if args.validation_prompt is not None:
+        if validation_prompts:
             validation_images = []
             pipeline = pipeline.to(accelerator.device)
             with torch.autocast(str(accelerator.device).replace(":0", "")):
                 validation_generator = torch.Generator(
                     device=accelerator.device
                 ).manual_seed(args.seed or 0)
-                validation_images = pipeline(
-                    prompt_embeds=validation_prompt_embeds,
-                    pooled_prompt_embeds=validation_pooled_embeds,
-                    negative_prompt_embeds=validation_negative_prompt_embeds,
-                    negative_pooled_prompt_embeds=validation_negative_pooled_embeds,
-                    num_images_per_prompt=args.num_validation_images,
-                    num_inference_steps=30,
-                    guidance_scale=args.validation_guidance,
-                    guidance_rescale=args.validation_guidance_rescale,
-                    generator=validation_generator,
-                    height=args.validation_resolution,
-                    width=args.validation_resolution,
-                ).images
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "wandb":
-                    wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                    idx = 0
-                    for validation_image in validation_images:
-                        tracker.log(
-                            {f"image-{idx}": wandb.Image(validation_images[idx])}
-                        )
-                        idx += 1
+                for validation_prompt in validation_prompts:
+                    # Each validation prompt needs its own embed.
+                    current_validation_prompt_embeds, current_validation_pooled_embeds = embed_cache.compute_embeddings_for_prompts(
+                        [validation_prompt]
+                    )
+                    validation_images.extend(pipeline(
+                        prompt_embeds=current_validation_prompt_embeds,
+                        pooled_prompt_embeds=current_validation_pooled_embeds,
+                        negative_prompt_embeds=validation_negative_prompt_embeds,
+                        negative_pooled_prompt_embeds=validation_negative_pooled_embeds,
+                        num_images_per_prompt=args.num_validation_images,
+                        num_inference_steps=30,
+                        guidance_scale=args.validation_guidance,
+                        guidance_rescale=args.validation_guidance_rescale,
+                        generator=validation_generator,
+                        height=args.validation_resolution,
+                        width=args.validation_resolution,
+                    ).images)
+
+            for tracker in accelerator.trackers:
+                if tracker.name == "wandb":
+                    validation_document = {}
+                    for idx, validation_image in enumerate(validation_images):
+                        # Create a WandB entry containing each image.
+                        validation_document[validation_shortnames[idx]] = wandb.Image(validation_image)
+                    tracker.log(validation_document, step=global_step)
+
     accelerator.end_training()
