5 changes: 5 additions & 0 deletions helpers/arguments.py
@@ -414,6 +414,11 @@ def parse_args(input_args=None):
         default=None,
         help="A prompt that is used during validation to verify that the model is learning.",
     )
+    parser.add_argument(
+        "--validation_prompt_library",
+        action="store_true",
+        help="If this is provided, the SimpleTuner prompt library will be used to generate multiple images.",
+    )
     parser.add_argument(
         "--num_validation_images",
         type=int,
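For reference, the library this new flag enables is a plain mapping of shortnames to prompt strings; the format is documented in a comment further down in the train_sdxl.py changes. A minimal sketch of the shape of helpers/prompts.py — the entries here are illustrative placeholders, not the library's actual contents:

# helpers/prompts.py -- shape only; the real entries ship with SimpleTuner.
# Format: { 'shortname': 'this is the prompt', ... }
prompts = {
    "portrait": "a studio portrait photograph of a woman, 85mm, f/1.8",   # illustrative
    "landscape": "a wide-angle photo of a mountain lake at golden hour",  # illustrative
}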
File renamed without changes: prompts.py → helpers/prompts.py.
2 changes: 1 addition & 1 deletion inference.py
@@ -1,7 +1,7 @@
 from accelerate import Accelerator
 from diffusers import DiffusionPipeline, UNet2DConditionModel, DDPMScheduler, DDIMScheduler
 from transformers import CLIPTextModel
-from prompts import prompts
+from helpers.prompts import prompts
 from compel import Compel

 import torch, os, logging
Empty file added toolkit/datasets/README.md
2 changes: 1 addition & 1 deletion toolkit/inference/inference_snr_test.py
@@ -2,7 +2,7 @@
 from PIL import Image
 from diffusers import StableDiffusionPipeline, DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDPMScheduler, DDIMScheduler
 from transformers import CLIPTextModel
-from prompts import prompts
+from helpers.prompts import prompts
 model_id = '/notebooks/datasets/models/pseudo-realism'
 #model_id = 'stabilityai/stable-diffusion-2-1'
 pipe = StableDiffusionPipeline.from_pretrained(model_id)
2 changes: 1 addition & 1 deletion toolkit/inference/tile_shortnames.py
@@ -1,6 +1,6 @@
 import os
 from PIL import Image, ImageDraw, ImageFont
-from prompts import prompts
+from helpers.prompts import prompts
 grid_dir = '/notebooks/SimpleTuner/grid'
 output_dir = '/notebooks/datasets/test_results'

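A practical consequence of the rename: prompts.py now lives inside the helpers package. inference.py resolves `from helpers.prompts import prompts` when run from the repository root, but the nested toolkit scripts also need the root on PYTHONPATH. A hedged sketch of a guard one could add at the top of the toolkit scripts — this is an assumption, not part of the PR:

import os
import sys

# Assumption: this script sits two levels below the SimpleTuner root
# (toolkit/inference/), so strip the filename and walk up two directories.
repo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

from helpers.prompts import prompts  # now resolves regardless of the CWD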
141 changes: 87 additions & 54 deletions train_sdxl.py
@@ -210,6 +210,7 @@ def main():
     )

     # Create EMA for the unet.
+    ema_unet = None
     if args.use_ema:
         logger.info("Using EMA. Creating EMAModel.")
         ema_unet = EMAModel(
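Defining ema_unet unconditionally is what lets the later `if args.use_ema: ema_unet.restore(...)` blocks reference the name without a NameError when EMA is disabled. A minimal sketch of the pattern, assuming the trainer's unet and args from the surrounding code:

from diffusers.training_utils import EMAModel  # the same helper train_sdxl.py uses

ema_unet = None  # always bound, even when EMA is disabled
if args.use_ema:
    ema_unet = EMAModel(unet.parameters())

# ... later, EMA-only paths stay guarded and safe either way:
if args.use_ema:
    ema_unet.restore(unet.parameters())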
@@ -571,11 +572,28 @@ def collate_fn(examples):
     logger.info(f"Pre-computing text embeds / updating cache.")
     embed_cache.precompute_embeddings_for_prompts(train_dataset.get_all_captions())

-    if args.validation_prompt is not None:
+    validation_prompts = []
+    validation_shortnames = []
+    if args.validation_prompt_library:
+        # Use the SimpleTuner prompts library for validation prompts.
+        from helpers.prompts import prompts as prompt_library
+        # Prompt format: { 'shortname': 'this is the prompt', ... }
+        for shortname, prompt in prompt_library.items():
+            logger.info(f'Precomputing validation prompt embeds: {shortname}')
+            embed_cache.compute_embeddings_for_prompts([prompt])
+            validation_prompts.append(prompt)
+            validation_shortnames.append(shortname)
+    elif args.validation_prompt is not None:
+        # Use a single prompt for validation.
+        validation_prompts = [args.validation_prompt]
+        validation_shortnames = ['validation']
         (
             validation_prompt_embeds,
             validation_pooled_embeds,
         ) = embed_cache.compute_embeddings_for_prompts([args.validation_prompt])

+    # Compute negative embed for validation prompts, if any are set.
+    if validation_prompts:
         (
             validation_negative_prompt_embeds,
             validation_negative_pooled_embeds,
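The single-prompt branch destructures the cache call into a prompt/pooled pair, which matches SDXL's two text encoders. A hedged sketch of how the cache is exercised on the library path — the names mirror the diff, while the return-shape commentary is an assumption based on the destructuring above:

# Warm the cache once per library prompt (as the loop above does) ...
for shortname, prompt in prompt_library.items():
    embed_cache.compute_embeddings_for_prompts([prompt])

# ... then a later call for a single prompt returns the cached pair:
prompt_embeds, pooled_embeds = embed_cache.compute_embeddings_for_prompts(
    ["a photograph of a corgi"]  # illustrative prompt
)
# prompt_embeds feeds prompt_embeds=, pooled_embeds feeds pooled_prompt_embeds=
# when the pipeline is called.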
@@ -885,8 +903,7 @@ def collate_fn(examples):

             ### BEGIN: Perform validation every `validation_epochs` steps
             if accelerator.is_main_process:
-                if global_step % args.validation_steps == 0 and global_step > 1:
-                    pass
+                if validation_prompts and global_step % args.validation_steps == 0 and global_step > 1:
                     if (
                         args.validation_prompt is None
                         or args.num_validation_images is None
@@ -896,7 +913,9 @@ def collate_fn(examples):
                             f"Not generating any validation images for this checkpoint. Live dangerously and prosper, pal!"
                         )
                         continue

+                    if args.gradient_accumulation_steps > 0 and step % args.gradient_accumulation_steps != 0:
+                        # We do not want to perform validation on a partial batch.
+                        continue
                     logger.info(
                         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                         f" {args.validation_prompt}."
@@ -919,7 +938,7 @@ def collate_fn(examples):
                         revision=args.revision,
                         torch_dtype=weight_dtype,
                     )
-                    pipeline.scheduler.config.prediction_type = "v_prediction"
+                    pipeline.scheduler.config.prediction_type = args.prediction_type or noise_scheduler.config.prediction_type
                     pipeline = pipeline.to(accelerator.device)
                     pipeline.set_progress_bar_config(disable=True)
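Dropping the hard-coded "v_prediction" matters because SDXL's base checkpoints are epsilon-prediction models, and sampling with the wrong prediction_type yields washed-out or noisy validation images. The replacement prefers an explicit --prediction_type and otherwise defers to the training scheduler; a small sketch of the fallback, assuming the trainer's args and noise_scheduler:

# `or` falls through when --prediction_type is unset (None or empty string).
prediction_type = args.prediction_type or noise_scheduler.config.prediction_type
assert prediction_type in ("epsilon", "v_prediction", "sample")  # diffusers' valid values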

@@ -936,22 +955,39 @@ def collate_fn(examples):
                             or accelerator.mixed_precision == "bf16"
                         ),
                     ):
-                        validation_generator = torch.Generator(
-                            device=accelerator.device
-                        ).manual_seed(args.seed or 0)
-                        validation_images = pipeline(
-                            prompt_embeds=validation_prompt_embeds,
-                            pooled_prompt_embeds=validation_pooled_embeds,
-                            negative_prompt_embeds=validation_negative_prompt_embeds,
-                            negative_pooled_prompt_embeds=validation_negative_pooled_embeds,
-                            num_images_per_prompt=args.num_validation_images,
-                            num_inference_steps=30,
-                            guidance_scale=args.validation_guidance,
-                            guidance_rescale=args.validation_guidance_rescale,
-                            generator=validation_generator,
-                            height=args.validation_resolution,
-                            width=args.validation_resolution,
-                        ).images
+                        validation_images = []
+                        pipeline = pipeline.to(accelerator.device)
+                        with torch.autocast(str(accelerator.device).replace(":0", "")):
+                            validation_generator = torch.Generator(
+                                device=accelerator.device
+                            ).manual_seed(args.seed or 0)
+                            for validation_prompt in validation_prompts:
+                                # Each validation prompt needs its own embed.
+                                current_validation_prompt_embeds, current_validation_pooled_embeds = embed_cache.compute_embeddings_for_prompts(
+                                    [validation_prompt]
+                                )
+                                logger.info(f'Generating validation image: {validation_prompt}')
+                                validation_images.extend(pipeline(
+                                    prompt_embeds=current_validation_prompt_embeds,
+                                    pooled_prompt_embeds=current_validation_pooled_embeds,
+                                    negative_prompt_embeds=validation_negative_prompt_embeds,
+                                    negative_pooled_prompt_embeds=validation_negative_pooled_embeds,
+                                    num_images_per_prompt=args.num_validation_images,
+                                    num_inference_steps=30,
+                                    guidance_scale=args.validation_guidance,
+                                    guidance_rescale=args.validation_guidance_rescale,
+                                    generator=validation_generator,
+                                    height=args.validation_resolution,
+                                    width=args.validation_resolution,
+                                ).images)
+
+                    for tracker in accelerator.trackers:
+                        if tracker.name == "wandb":
+                            validation_document = {}
+                            for idx, validation_image in enumerate(validation_images):
+                                # Create a WandB entry containing each image.
+                                validation_document[validation_shortnames[idx]] = wandb.Image(validation_image)
+                            tracker.log(validation_document, step=global_step)
                     val_img_idx = 0
                     for a_val_img in validation_images:
                         a_val_img.save(
@@ -962,14 +998,6 @@ def collate_fn(examples):
                         )
                         val_img_idx += 1

-                    for tracker in accelerator.trackers:
-                        if tracker.name == "wandb":
-                            idx = 0
-                            for validation_image in validation_images:
-                                tracker.log(
-                                    {f"image-{idx}": wandb.Image(validation_images[idx])}
-                                )
-                                idx += 1
                     if args.use_ema:
                         # Switch back to the original UNet parameters.
                         ema_unet.restore(unet.parameters())
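Logging one dict keyed by shortname replaces the old image-0, image-1 naming, giving each library prompt a stable panel in the WandB UI across steps. One caveat worth flagging: indexing validation_shortnames by image index only lines up while --num_validation_images is 1; with more images per prompt, the two lists diverge. A minimal sketch of the resulting log call, with placeholder images:

import wandb
from PIL import Image

# Assumes a wandb run is already active (accelerate's tracker setup does this).
validation_document = {
    "portrait": wandb.Image(Image.new("RGB", (64, 64))),   # placeholder image
    "landscape": wandb.Image(Image.new("RGB", (64, 64))),  # placeholder image
}
wandb.log(validation_document, step=1000)  # step= keeps panels aligned to global_step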
@@ -998,7 +1026,8 @@ def collate_fn(examples):
             unet=unet,
             revision=args.revision,
         )
-        pipeline.save_pretrained("/notebooks/datasets/models/ptx0-xltest")
+        pipeline.scheduler.config = noise_scheduler.config
+        pipeline.save_pretrained("/notebooks/datasets/models/ptx0-xltest", safe_serialization=True)

         if args.push_to_hub:
             upload_folder(
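Two things change in the export: the saved pipeline now carries the scheduler config the model was actually trained with, and safe_serialization=True writes safetensors rather than pickle. A hedged sketch of loading the result back — the path mirrors the diff, and the scheduler swap is ordinary diffusers usage:

import torch
from diffusers import StableDiffusionXLPipeline, DDIMScheduler

pipe = StableDiffusionXLPipeline.from_pretrained(
    "/notebooks/datasets/models/ptx0-xltest",  # path from the diff
    torch_dtype=torch.float16,
)
# The exported scheduler config (prediction_type included) rides along, so a
# swapped-in sampler inherits it:
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)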
@@ -1008,36 +1037,40 @@ def collate_fn(examples):
                 ignore_patterns=["step_*", "epoch_*"],
             )

-        if args.validation_prompt is not None:
+        if validation_prompts:
             validation_images = []
             pipeline = pipeline.to(accelerator.device)
             with torch.autocast(str(accelerator.device).replace(":0", "")):
                 validation_generator = torch.Generator(
                     device=accelerator.device
                 ).manual_seed(args.seed or 0)
-                validation_images = pipeline(
-                    prompt_embeds=validation_prompt_embeds,
-                    pooled_prompt_embeds=validation_pooled_embeds,
-                    negative_prompt_embeds=validation_negative_prompt_embeds,
-                    negative_pooled_prompt_embeds=validation_negative_pooled_embeds,
-                    num_images_per_prompt=args.num_validation_images,
-                    num_inference_steps=30,
-                    guidance_scale=args.validation_guidance,
-                    guidance_rescale=args.validation_guidance_rescale,
-                    generator=validation_generator,
-                    height=args.validation_resolution,
-                    width=args.validation_resolution,
-                ).images
+                for validation_prompt in validation_prompts:
+                    # Each validation prompt needs its own embed.
+                    current_validation_prompt_embeds, current_validation_pooled_embeds = embed_cache.compute_embeddings_for_prompts(
+                        [validation_prompt]
+                    )
+                    validation_images.extend(pipeline(
+                        prompt_embeds=current_validation_prompt_embeds,
+                        pooled_prompt_embeds=current_validation_pooled_embeds,
+                        negative_prompt_embeds=validation_negative_prompt_embeds,
+                        negative_pooled_prompt_embeds=validation_negative_pooled_embeds,
+                        num_images_per_prompt=args.num_validation_images,
+                        num_inference_steps=30,
+                        guidance_scale=args.validation_guidance,
+                        guidance_rescale=args.validation_guidance_rescale,
+                        generator=validation_generator,
+                        height=args.validation_resolution,
+                        width=args.validation_resolution,
+                    ).images)

-            for tracker in accelerator.trackers:
-                if tracker.name == "wandb":
-                    wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                    idx = 0
-                    for validation_image in validation_images:
-                        tracker.log(
-                            {f"image-{idx}": wandb.Image(validation_images[idx])}
-                        )
-                        idx += 1
+            for tracker in accelerator.trackers:
+                if tracker.name == "wandb":
+                    validation_document = {}
+                    for idx, validation_image in enumerate(validation_images):
+                        # Create a WandB entry containing each image.
+                        validation_document[validation_shortnames[idx]] = wandb.Image(validation_image)
+                    tracker.log(validation_document, step=global_step)

     accelerator.end_training()

