
Commit a61603d

bghira authored
New dataloader arguments for deleting problematic images (#30)
* DataLoader: make it optional to delete unwanted images, off by default
* Arguments: add terminal SNR parameters for tweaking, rather than being baked-in
* Make terminal SNR opt-in

---------

Co-authored-by: bghira <[email protected]>
1 parent c1a61bb commit a61603d

File tree

4 files changed (+103, -33 lines)

helpers/arguments.py
helpers/aspect_bucket.py
sdxl-env.sh.example
train_sdxl.py


helpers/arguments.py

Lines changed: 57 additions & 10 deletions
@@ -2,7 +2,9 @@
 
 
 def parse_args(input_args=None):
-    parser = argparse.ArgumentParser(description="The following SimpleTuner command-line options are available:")
+    parser = argparse.ArgumentParser(
+        description="The following SimpleTuner command-line options are available:"
+    )
     parser.add_argument(
         "--snr_gamma",
         type=float,
@@ -34,6 +36,34 @@ def parse_args(input_args=None):
             " SD 1.5 is epsilon."
         ),
     )
+    parser.add_argument(
+        '--training_scheduler_timestep_spacing',
+        type=str,
+        default="leading",
+        choices=["leading", "linspace", "trailing"],
+        help=(
+            "Spacing timesteps can fundamentally alter the course of history. Er, I mean, your model weights."
+            " For all training, including terminal SNR, it would seem that 'leading' is the right choice."
+            " However, for inference in terminal SNR models, 'trailing' is the correct choice."
+        )
+    )
+    parser.add_argument(
+        '--inference_scheduler_timestep_spacing',
+        type=str,
+        default="trailing",
+        choices=["leading", "linspace", "trailing"],
+        help=(
+            "The Bytedance paper on zero terminal SNR recommends inference using 'trailing'."
+        )
+    )
+    parser.add_argument(
+        '--rescale_betas_zero_snr',
+        action="store_true",
+        help=(
+            "If set, will rescale the betas to zero terminal SNR. This is recommended for training with v_prediction."
+            " For epsilon, this might help with fine details, but will not result in contrast improvements."
+        )
+    )
     parser.add_argument(
         "--vae_dtype",
         type=str,
@@ -113,13 +143,13 @@ def parse_args(input_args=None):
         "--seen_state_path",
         type=str,
         default=None,
-        help="Where the JSON document containing the state of the seen images is stored. This helps ensure we do not repeat images too many times."
+        help="Where the JSON document containing the state of the seen images is stored. This helps ensure we do not repeat images too many times.",
     )
     parser.add_argument(
         "--state_path",
         type=str,
         default=None,
-        help="A JSON document containing the current state of training, will be placed here."
+        help="A JSON document containing the current state of training, will be placed here.",
     )
     parser.add_argument(
         "--caption_strategy",
@@ -156,6 +186,15 @@ def parse_args(input_args=None):
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--minimum_image_size",
+        type=int,
+        default=768,
+        help=(
+            "The minimum resolution for both sides of input images."
+            " If --delete_unwanted_images is set, images smaller than this will be DELETED."
+        ),
+    )
     parser.add_argument(
         "--crops_coords_top_left_h",
         type=int,
@@ -235,9 +274,7 @@ def parse_args(input_args=None):
         "--checkpoints_total_limit",
         type=int,
         default=None,
-        help=(
-            "Max number of checkpoints to store."
-        ),
+        help=("Max number of checkpoints to store."),
     )
     parser.add_argument(
         "--resume_from_checkpoint",
@@ -299,7 +336,9 @@ def parse_args(input_args=None):
         help="Power factor of the polynomial scheduler.",
     )
     parser.add_argument(
-        "--use_ema", action="store_true", help="Whether to use EMA (exponential moving average) model."
+        "--use_ema",
+        action="store_true",
+        help="Whether to use EMA (exponential moving average) model.",
     )
     parser.add_argument(
         "--non_ema_revision",
@@ -485,13 +524,13 @@ def parse_args(input_args=None):
         help="Run validation every X epochs.",
     )
     parser.add_argument(
-        '--validation_guidance',
+        "--validation_guidance",
         type=float,
         default=7.5,
         help="CFG value for validation images. Default: 7.5",
     )
     parser.add_argument(
-        '--validation_guidance_rescale',
+        "--validation_guidance_rescale",
        type=float,
         default=0.0,
         help="CFG rescale value for validation images. Default: 0.0, max 1.0",
@@ -593,7 +632,15 @@ def parse_args(input_args=None):
         help=(
             "When this option is provided, image cropping and processing will be disabled."
            " It is a good idea to use this with caution, for training multiple aspect ratios."
-        )
+        ),
+    )
+    parser.add_argument(
+        "--delete_unwanted_images",
+        action="store_true",
+        help=(
+            "If set, will delete images that are not of a minimum size to save on disk space for large training runs."
+            " Default behaviour: Unset, remove images from bucket only."
+        ),
     )
     parser.add_argument(
         "--offset_noise",

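For a quick sanity check of the new options, here is a minimal, self-contained sketch that mirrors the five new arguments with a throwaway argparse parser. It deliberately does not import helpers/arguments.py, so it makes no assumptions about which of the project's other arguments are required:

import argparse

# Hypothetical mirror of the new options shown in the diff above; this is not
# the project's parse_args() itself.
parser = argparse.ArgumentParser()
parser.add_argument("--training_scheduler_timestep_spacing", type=str,
                    default="leading", choices=["leading", "linspace", "trailing"])
parser.add_argument("--inference_scheduler_timestep_spacing", type=str,
                    default="trailing", choices=["leading", "linspace", "trailing"])
parser.add_argument("--rescale_betas_zero_snr", action="store_true")
parser.add_argument("--minimum_image_size", type=int, default=768)
parser.add_argument("--delete_unwanted_images", action="store_true")

args = parser.parse_args(["--rescale_betas_zero_snr", "--minimum_image_size=1024"])
print(args.rescale_betas_zero_snr)               # True: opted in to zero terminal SNR
print(args.training_scheduler_timestep_spacing)  # "leading" (the training default)
print(args.delete_unwanted_images)               # False: small images only leave the bucket
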
helpers/aspect_bucket.py

Lines changed: 28 additions & 20 deletions
@@ -20,33 +20,38 @@
 class BalancedBucketSampler(torch.utils.data.Sampler):
     def __init__(
         self,
-        aspect_ratio_bucket_indices,
+        aspect_ratio_bucket_indices: dict,
         batch_size: int = 15,
         seen_images_path: str = "/notebooks/SimpleTuner/seen_images.json",
         state_path: str = "/notebooks/SimpleTuner/bucket_sampler_state.json",
-        reset_threshold: int = 5000,  # Add a reset_threshold
+        reset_threshold: int = 5000,
         debug_aspect_buckets: bool = False,
+        delete_unwanted_images: bool = False,
+        minimum_image_size: int = None,
     ):
         """
-        Initialize the BalancedBucketSampler instance.
+        Initializes the sampler with provided settings.
 
-        Args:
-            aspect_ratio_bucket_indices (dict): A dictionary mapping aspect ratios to image paths.
-            batch_size (int): The number of images per sample during training.
-            seen_images_path (str): The path to save/load the seen images.
-            state_path (str): The path to save/load the state of the sampler.
-            reset_threshold (int): The number of seen images to trigger a reset.
-            debug_aspect_buckets (bool): If True, enable debug logging.
+        Parameters:
+        - aspect_ratio_bucket_indices: Dictionary containing aspect ratios as keys and list of image paths as values.
+        - batch_size: Number of samples to draw per batch.
+        - seen_images_path: Path to store the seen images.
+        - state_path: Path to store the current state of the sampler.
+        - reset_threshold: The threshold after which the seen images list should be reset.
+        - debug_aspect_buckets: Flag to log state for debugging purposes.
+        - delete_unwanted_images: Flag to decide whether to delete unwanted (small) images or just remove from the bucket.
         """
         self.aspect_ratio_bucket_indices = aspect_ratio_bucket_indices
-        self.buckets = self.load_buckets()
+        self.buckets = list(self.aspect_ratio_bucket_indices.keys())
         self.exhausted_buckets = []
         self.batch_size = batch_size
-        self.current_bucket = 0
         self.seen_images_path = seen_images_path
         self.state_path = state_path
         self.reset_threshold = reset_threshold
         self.debug_aspect_buckets = debug_aspect_buckets
+        self.delete_unwanted_images = delete_unwanted_images
+        self.current_bucket = 0
+        self.minimum_image_size = minimum_image_size
         self.seen_images = self.load_seen_images()
 
     def save_state(self):
@@ -85,13 +90,16 @@ def remove_image(self, image_path, bucket):
         self.aspect_ratio_bucket_indices[bucket].remove(image_path)
 
     def handle_small_image(self, image_path, bucket):
-        logger.warning(f"Image too small: DELETING image and continuing search.")
-        # try:
-        #     os.remove(image_path)
-        # except Exception as e:
-        #     logger.warning(
-        #         f"The image was already deleted. Another GPU must have gotten to it."
-        #     )
+        if self.delete_unwanted_images:
+            try:
+                logger.warning(f"Image too small: DELETING image and continuing search.")
+                os.remove(image_path)
+            except Exception as e:
+                logger.warning(
+                    f"The image was already deleted. Another GPU must have gotten to it."
+                )
+        else:
+            logger.warning(f"Image too small, but --delete_unwanted_images is not provided, so we simply ignore and remove from bucket.")
         self.remove_image(image_path, bucket)
 
     def handle_incorrect_bucket(self, image_path, bucket, actual_bucket):
@@ -194,7 +202,7 @@ def __iter__(self):
             except:
                 logger.warning(f"Image was bad or in-progress: {image_path}")
                 continue
-            if image.width < 880 or image.height < 880:
+            if image.width < self.minimum_image_size or image.height < self.minimum_image_size:
                 image.close()
                 self.handle_small_image(image_path, bucket)
                 continue
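
A rough usage sketch for the updated sampler follows; it is not part of the commit. The bucket dictionary and JSON paths are placeholder values, the import path assumes the repository layout shown above, and it assumes the state files do not need to exist beforehand:

from helpers.aspect_bucket import BalancedBucketSampler

# Placeholder bucket index mapping aspect ratio -> list of image paths.
aspect_ratio_bucket_indices = {"1.0": ["/data/example-0001.png", "/data/example-0002.png"]}

sampler = BalancedBucketSampler(
    aspect_ratio_bucket_indices=aspect_ratio_bucket_indices,
    batch_size=15,
    seen_images_path="/tmp/seen_images.json",      # placeholder, not the notebook default
    state_path="/tmp/bucket_sampler_state.json",   # placeholder, not the notebook default
    delete_unwanted_images=False,  # new: undersized files stay on disk and only leave the bucket
    minimum_image_size=768,        # new: matches the --minimum_image_size default
)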

sdxl-env.sh.example

Lines changed: 6 additions & 1 deletion
@@ -87,4 +87,9 @@ export TRAINER_EXTRA_ARGS="--allow_tf32 --use_8bit_adam --use_ema" # anything y
 
 # These are pretty sketchy to change. --use_original_images can be removed to enable image cropping. Not tested for SDXL.
 export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --enable_xformers_memory_efficient_attention --use_original_images=true"
-export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --gradient_checkpointing --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS}"
+export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --gradient_checkpointing --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS}"
+
+## For terminal SNR training:
+
+#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --prediction_type=v_prediction --rescale_betas_zero_snr"
+#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --training_scheduler_timestep_spacing=leading --inference_scheduler_timestep_spacing=trailing"

train_sdxl.py

Lines changed: 12 additions & 2 deletions
@@ -417,12 +417,14 @@ def tokenize_captions(captions, tokenizer):
         args.pretrained_model_name_or_path,
         subfolder="scheduler",
         prediction_type=args.prediction_type,
-        rescale_betas_zero_snr=True,
+        timestep_spacing=args.training_scheduler_timestep_spacing,
+        rescale_betas_zero_snr=args.rescale_betas_zero_snr,
     )
     noise_scheduler = DDPMScheduler.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="scheduler",
         prediction_type=args.prediction_type,
+        timestep_spacing=args.training_scheduler_timestep_spacing,
         trained_betas=betas_scheduler.betas.numpy().tolist(),
     )
     text_encoder_1 = text_encoder_cls_1.from_pretrained(
@@ -554,6 +556,8 @@ def collate_fn(examples):
         seen_images_path=args.seen_state_path,
         state_path=args.state_path,
         debug_aspect_buckets=args.debug_aspect_buckets,
+        delete_unwanted_images=args.delete_unwanted_images,
+        minimum_image_size=args.minimum_image_size
     )
     logger.info("Plugging sampler into dataloader")
     train_dataloader = torch.utils.data.DataLoader(
@@ -938,7 +942,13 @@ def collate_fn(examples):
                 revision=args.revision,
                 torch_dtype=weight_dtype,
             )
-            pipeline.scheduler.config.prediction_type = args.prediction_type or noise_scheduler.config.prediction_type
+            pipeline.scheduler = DDIMScheduler.from_pretrained(
+                args.pretrained_model_name_or_path,
+                subfolder="scheduler",
+                prediction_type=args.prediction_type,
+                timestep_spacing=args.inference_scheduler_timestep_spacing,
+                rescale_betas_zero_snr=args.rescale_betas_zero_snr,
+            )
             pipeline = pipeline.to(accelerator.device)
             pipeline.set_progress_bar_config(disable=True)
 
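
To illustrate the scheduler wiring above outside the trainer, here is a hedged sketch that goes straight through diffusers. The model ID is only an example, prediction_type is set to v_prediction because terminal SNR is the opt-in case this commit targets, and the timestep_spacing and rescale_betas_zero_snr keyword arguments assume a diffusers release recent enough to expose them in the DDIM/DDPM scheduler configs:

from diffusers import DDIMScheduler, DDPMScheduler

model_id = "stabilityai/stable-diffusion-xl-base-1.0"  # example checkpoint, not hard-coded by the trainer

# DDIM scheduler with zero terminal SNR rescaling; the diff builds one like this
# for the validation pipeline (--inference_scheduler_timestep_spacing defaults to "trailing").
betas_scheduler = DDIMScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    prediction_type="v_prediction",
    timestep_spacing="trailing",
    rescale_betas_zero_snr=True,
)

# DDPM noise scheduler for training, mirroring --training_scheduler_timestep_spacing
# ("leading" by default) and reusing the rescaled betas via trained_betas, as the diff does.
noise_scheduler = DDPMScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    prediction_type="v_prediction",
    timestep_spacing="leading",
    trained_betas=betas_scheduler.betas.numpy().tolist(),
)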

0 commit comments
