@@ -210,6 +210,7 @@ def main():
210210 )
211211
212212 # Create EMA for the unet.
213+ ema_unet = None
213214 if args .use_ema :
214215 logger .info ("Using EMA. Creating EMAModel." )
215216 ema_unet = EMAModel (
@@ -571,11 +572,28 @@ def collate_fn(examples):
571572 logger .info (f"Pre-computing text embeds / updating cache." )
572573 embed_cache .precompute_embeddings_for_prompts (train_dataset .get_all_captions ())
573574
574- if args .validation_prompt is not None :
575+ validation_prompts = []
576+ validation_shortnames = []
577+ if args .validation_prompt_library :
578+ # Use the SimpleTuner prompts library for validation prompts.
579+ from helpers .prompts import prompts as prompt_library
580+ # Prompt format: { 'shortname': 'this is the prompt', ... }
581+ for shortname , prompt in prompt_library .items ():
582+ logger .info (f'Precomputing validation prompt embeds: { shortname } ' )
583+ embed_cache .compute_embeddings_for_prompts ([prompt ])
584+ validation_prompts .append (prompt )
585+ validation_shortnames .append (shortname )
586+ elif args .validation_prompt is not None :
587+ # Use a single prompt for validation.
588+ validation_prompts = [args .validation_prompt ]
589+ validation_shortnames = ['validation' ]
575590 (
576591 validation_prompt_embeds ,
577592 validation_pooled_embeds ,
578593 ) = embed_cache .compute_embeddings_for_prompts ([args .validation_prompt ])
594+
595+ # Compute negative embed for validation prompts, if any are set.
596+ if validation_prompts :
579597 (
580598 validation_negative_prompt_embeds ,
581599 validation_negative_pooled_embeds ,
@@ -885,8 +903,7 @@ def collate_fn(examples):
885903
886904 ### BEGIN: Perform validation every `validation_epochs` steps
887905 if accelerator .is_main_process :
888- if global_step % args .validation_steps == 0 and global_step > 1 :
889- pass
906+ if validation_prompts and global_step % args .validation_steps == 0 and global_step > 1 :
890907 if (
891908 args .validation_prompt is None
892909 or args .num_validation_images is None
@@ -896,7 +913,9 @@ def collate_fn(examples):
896913 f"Not generating any validation images for this checkpoint. Live dangerously and prosper, pal!"
897914 )
898915 continue
899-
916+ if args .gradient_accumulation_steps > 0 and step % args .gradient_accumulation_steps != 0 :
917+ # We do not want to perform validation on a partial batch.
918+ continue
900919 logger .info (
901920 f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
902921 f" { args .validation_prompt } ."
@@ -919,7 +938,7 @@ def collate_fn(examples):
919938 revision = args .revision ,
920939 torch_dtype = weight_dtype ,
921940 )
922- pipeline .scheduler .config .prediction_type = "v_prediction"
941+ pipeline .scheduler .config .prediction_type = args . prediction_type or noise_scheduler . config . prediction_type
923942 pipeline = pipeline .to (accelerator .device )
924943 pipeline .set_progress_bar_config (disable = True )
925944
@@ -936,22 +955,39 @@ def collate_fn(examples):
936955 or accelerator .mixed_precision == "bf16"
937956 ),
938957 ):
939- validation_generator = torch .Generator (
940- device = accelerator .device
941- ).manual_seed (args .seed or 0 )
942- validation_images = pipeline (
943- prompt_embeds = validation_prompt_embeds ,
944- pooled_prompt_embeds = validation_pooled_embeds ,
945- negative_prompt_embeds = validation_negative_prompt_embeds ,
946- negative_pooled_prompt_embeds = validation_negative_pooled_embeds ,
947- num_images_per_prompt = args .num_validation_images ,
948- num_inference_steps = 30 ,
949- guidance_scale = args .validation_guidance ,
950- guidance_rescale = args .validation_guidance_rescale ,
951- generator = validation_generator ,
952- height = args .validation_resolution ,
953- width = args .validation_resolution ,
954- ).images
958+ validation_images = []
959+ pipeline = pipeline .to (accelerator .device )
960+ with torch .autocast (str (accelerator .device ).replace (":0" , "" )):
961+ validation_generator = torch .Generator (
962+ device = accelerator .device
963+ ).manual_seed (args .seed or 0 )
964+ for validation_prompt in validation_prompts :
965+ # Each validation prompt needs its own embed.
966+ current_validation_prompt_embeds , current_validation_pooled_embeds = embed_cache .compute_embeddings_for_prompts (
967+ [validation_prompt ]
968+ )
969+ logger .info (f'Generating validation image: { validation_prompt } ' )
970+ validation_images .extend (pipeline (
971+ prompt_embeds = current_validation_prompt_embeds ,
972+ pooled_prompt_embeds = current_validation_pooled_embeds ,
973+ negative_prompt_embeds = validation_negative_prompt_embeds ,
974+ negative_pooled_prompt_embeds = validation_negative_pooled_embeds ,
975+ num_images_per_prompt = args .num_validation_images ,
976+ num_inference_steps = 30 ,
977+ guidance_scale = args .validation_guidance ,
978+ guidance_rescale = args .validation_guidance_rescale ,
979+ generator = validation_generator ,
980+ height = args .validation_resolution ,
981+ width = args .validation_resolution ,
982+ ).images )
983+
984+ for tracker in accelerator .trackers :
985+ if tracker .name == "wandb" :
986+ validation_document = {}
987+ for idx , validation_image in enumerate (validation_images ):
988+ # Create a WandB entry containing each image.
989+ validation_document [validation_shortnames [idx ]] = wandb .Image (validation_image )
990+ tracker .log (validation_document , step = global_step )
955991 val_img_idx = 0
956992 for a_val_img in validation_images :
957993 a_val_img .save (
@@ -962,14 +998,6 @@ def collate_fn(examples):
962998 )
963999 val_img_idx += 1
9641000
965- for tracker in accelerator .trackers :
966- if tracker .name == "wandb" :
967- idx = 0
968- for validation_image in validation_images :
969- tracker .log (
970- {f"image-{ idx } " : wandb .Image (validation_images [idx ])}
971- )
972- idx += 1
9731001 if args .use_ema :
9741002 # Switch back to the original UNet parameters.
9751003 ema_unet .restore (unet .parameters ())
@@ -998,7 +1026,8 @@ def collate_fn(examples):
9981026 unet = unet ,
9991027 revision = args .revision ,
10001028 )
1001- pipeline .save_pretrained ("/notebooks/datasets/models/ptx0-xltest" )
1029+ pipeline .scheduler .config = noise_scheduler .config
1030+ pipeline .save_pretrained ("/notebooks/datasets/models/ptx0-xltest" , safe_serialization = True )
10021031
10031032 if args .push_to_hub :
10041033 upload_folder (
@@ -1008,36 +1037,40 @@ def collate_fn(examples):
10081037 ignore_patterns = ["step_*" , "epoch_*" ],
10091038 )
10101039
1011- if args . validation_prompt is not None :
1040+ if validation_prompts :
10121041 validation_images = []
10131042 pipeline = pipeline .to (accelerator .device )
10141043 with torch .autocast (str (accelerator .device ).replace (":0" , "" )):
10151044 validation_generator = torch .Generator (
10161045 device = accelerator .device
10171046 ).manual_seed (args .seed or 0 )
1018- validation_images = pipeline (
1019- prompt_embeds = validation_prompt_embeds ,
1020- pooled_prompt_embeds = validation_pooled_embeds ,
1021- negative_prompt_embeds = validation_negative_prompt_embeds ,
1022- negative_pooled_prompt_embeds = validation_negative_pooled_embeds ,
1023- num_images_per_prompt = args .num_validation_images ,
1024- num_inference_steps = 30 ,
1025- guidance_scale = args .validation_guidance ,
1026- guidance_rescale = args .validation_guidance_rescale ,
1027- generator = validation_generator ,
1028- height = args .validation_resolution ,
1029- width = args .validation_resolution ,
1030- ).images
1031-
1032- for tracker in accelerator .trackers :
1033- if tracker .name == "wandb" :
1034- wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
1035- idx = 0
1036- for validation_image in validation_images :
1037- tracker .log (
1038- {f"image-{ idx } " : wandb .Image (validation_images [idx ])}
1039- )
1040- idx += 1
1047+ for validation_prompt in validation_prompts :
1048+ # Each validation prompt needs its own embed.
1049+ current_validation_prompt_embeds , current_validation_pooled_embeds = embed_cache .compute_embeddings_for_prompts (
1050+ [validation_prompt ]
1051+ )
1052+ validation_images .extend (pipeline (
1053+ prompt_embeds = current_validation_prompt_embeds ,
1054+ pooled_prompt_embeds = current_validation_pooled_embeds ,
1055+ negative_prompt_embeds = validation_negative_prompt_embeds ,
1056+ negative_pooled_prompt_embeds = validation_negative_pooled_embeds ,
1057+ num_images_per_prompt = args .num_validation_images ,
1058+ num_inference_steps = 30 ,
1059+ guidance_scale = args .validation_guidance ,
1060+ guidance_rescale = args .validation_guidance_rescale ,
1061+ generator = validation_generator ,
1062+ height = args .validation_resolution ,
1063+ width = args .validation_resolution ,
1064+ ).images )
1065+
1066+ for tracker in accelerator .trackers :
1067+ if tracker .name == "wandb" :
1068+ validation_document = {}
1069+ for idx , validation_image in enumerate (validation_images ):
1070+ # Create a WandB entry containing each image.
1071+ validation_document [validation_shortnames [idx ]] = wandb .Image (validation_image )
1072+ tracker .log (validation_document , step = global_step )
1073+
10411074 accelerator .end_training ()
10421075
10431076
0 commit comments