68 changes: 53 additions & 15 deletions convert_checkpoint.py
@@ -200,8 +200,14 @@ def convert_vae_state_dict(vae_state_dict):
     (".c_proj.", ".fc2."),
     (".attn", ".self_attn"),
     ("ln_final.", "transformer.text_model.final_layer_norm."),
-    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
-    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
+    (
+        "token_embedding.weight",
+        "transformer.text_model.embeddings.token_embedding.weight",
+    ),
+    (
+        "positional_embedding",
+        "transformer.text_model.embeddings.position_embedding.weight",
+    ),
 ]
 protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
 textenc_pattern = re.compile("|".join(protected.keys()))
@@ -239,19 +245,29 @@ def convert_text_enc_state_dict_v20(text_enc_dict):
             capture_qkv_bias[k_pre][code2idx[k_code]] = v
             continue
 
-        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
+        relabelled_key = textenc_pattern.sub(
+            lambda m: protected[re.escape(m.group(0))], k
+        )
         new_state_dict[relabelled_key] = v
 
     for k_pre, tensors in capture_qkv_weight.items():
         if None in tensors:
-            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
-        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+            raise Exception(
+                "CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing"
+            )
+        relabelled_key = textenc_pattern.sub(
+            lambda m: protected[re.escape(m.group(0))], k_pre
+        )
         new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
 
     for k_pre, tensors in capture_qkv_bias.items():
         if None in tensors:
-            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
-        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+            raise Exception(
+                "CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing"
+            )
+        relabelled_key = textenc_pattern.sub(
+            lambda m: protected[re.escape(m.group(0))], k_pre
+        )
         new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
 
     return new_state_dict
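The q-k-v handling above exists because the SD 2.x (open_clip) text encoder stores attention as one fused projection. A minimal sketch of the concatenation that `convert_text_enc_state_dict_v20` performs, with toy tensors (names and sizes here are illustrative, not from the PR):

```python
import torch

embed_dim = 8
# Stand-ins for the captured q / k / v projection weights.
q, k, v = (torch.randn(embed_dim, embed_dim) for _ in range(3))

# torch.cat with the default dim=0 stacks the three matrices into one
# (3 * embed_dim, embed_dim) tensor, the fused in_proj_weight layout.
in_proj_weight = torch.cat([q, k, v])
assert in_proj_weight.shape == (3 * embed_dim, embed_dim)
```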
@@ -264,11 +280,27 @@ def convert_text_enc_state_dict(text_enc_dict):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
-    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
-    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
-    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
     parser.add_argument(
-        "--use_safetensors", action="store_true", help="Save weights using safetensors; the default is ckpt."
+        "--model_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to the model to convert.",
+    )
+    parser.add_argument(
+        "--checkpoint_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to the output model.",
+    )
+    parser.add_argument(
+        "--half", action="store_true", help="Save weights in half precision."
+    )
+    parser.add_argument(
+        "--use_safetensors",
+        action="store_true",
+        help="Save weights using safetensors; the default is ckpt.",
     )
 
     args = parser.parse_args()
@@ -303,7 +335,9 @@ def convert_text_enc_state_dict(text_enc_dict):
 
     # Convert the UNet model
     unet_state_dict = convert_unet_state_dict(unet_state_dict)
-    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+    unet_state_dict = {
+        "model.diffusion_model." + k: v for k, v in unet_state_dict.items()
+    }
 
     # Convert the VAE model
     vae_state_dict = convert_vae_state_dict(vae_state_dict)
@@ -316,10 +350,14 @@ def convert_text_enc_state_dict(text_enc_dict):
         # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
         text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
         text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
-        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
+        text_enc_dict = {
+            "cond_stage_model.model." + k: v for k, v in text_enc_dict.items()
+        }
     else:
         text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
-        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
+        text_enc_dict = {
+            "cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()
+        }
 
     # Put together new checkpoint
     state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
@@ -330,4 +368,4 @@ def convert_text_enc_state_dict(text_enc_dict):
         save_file(state_dict, args.checkpoint_path)
     else:
         state_dict = {"state_dict": state_dict}
-        torch.save(state_dict, args.checkpoint_path)
\ No newline at end of file
+        torch.save(state_dict, args.checkpoint_path)
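A quick sanity check on the merged checkpoint the script writes, assuming it was run with --use_safetensors; the output file name here is hypothetical:

```python
from safetensors.torch import load_file

state_dict = load_file("sd_checkpoint.safetensors")  # hypothetical output path
# The merged dict should carry the prefixes attached above for each component.
assert any(k.startswith("model.diffusion_model.") for k in state_dict)  # UNet
assert any(k.startswith("cond_stage_model.") for k in state_dict)  # text encoder
```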
16 changes: 8 additions & 8 deletions helpers/arguments.py
@@ -37,32 +37,32 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        '--training_scheduler_timestep_spacing',
+        "--training_scheduler_timestep_spacing",
         type=str,
         default="leading",
         choices=["leading", "linspace", "trailing"],
         help=(
             "Spacing timesteps can fundamentally alter the course of history. Er, I mean, your model weights."
             " For all training, including terminal SNR, it would seem that 'leading' is the right choice."
             " However, for inference in terminal SNR models, 'trailing' is the correct choice."
-        )
+        ),
     )
     parser.add_argument(
-        '--inference_scheduler_timestep_spacing',
+        "--inference_scheduler_timestep_spacing",
         type=str,
         default="trailing",
         choices=["leading", "linspace", "trailing"],
         help=(
             "The Bytedance paper on zero terminal SNR recommends inference using 'trailing'."
-        )
+        ),
     )
     parser.add_argument(
-        '--rescale_betas_zero_snr',
+        "--rescale_betas_zero_snr",
         action="store_true",
         help=(
             "If set, will rescale the betas to zero terminal SNR. This is recommended for training with v_prediction."
             " For epsilon, this might help with fine details, but will not result in contrast improvements."
-        )
+        ),
     )
     parser.add_argument(
         "--vae_dtype",
@@ -142,13 +142,13 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--seen_state_path",
         type=str,
-        default='seen_state.json',
+        default="seen_state.json",
        help="Where the JSON document containing the state of the seen images is stored. This helps ensure we do not repeat images too many times.",
     )
     parser.add_argument(
         "--state_path",
         type=str,
-        default='training_state.json',
+        default="training_state.json",
         help="A JSON document containing the current state of training, will be placed here.",
     )
     parser.add_argument(
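The timestep-spacing and zero-SNR flags above mirror scheduler options in diffusers. A minimal sketch of how they would typically be wired up at inference time, assuming a diffusers DDIMScheduler and a hypothetical model id:

```python
from diffusers import DDIMScheduler

scheduler = DDIMScheduler.from_pretrained(
    "some-org/some-sd-model",  # hypothetical; not from this PR
    subfolder="scheduler",
    timestep_spacing="trailing",  # cf. --inference_scheduler_timestep_spacing
    rescale_betas_zero_snr=True,  # cf. --rescale_betas_zero_snr
)
```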
24 changes: 19 additions & 5 deletions helpers/aspect_bucket.py
@@ -92,14 +92,18 @@ def remove_image(self, image_path, bucket):
     def handle_small_image(self, image_path, bucket):
         if self.delete_unwanted_images:
             try:
-                logger.warning(f"Image too small: DELETING image and continuing search.")
+                logger.warning(
+                    f"Image too small: DELETING image and continuing search."
+                )
                 os.remove(image_path)
             except Exception as e:
                 logger.warning(
                     f"The image was already deleted. Another GPU must have gotten to it."
                 )
         else:
-            logger.warning(f"Image too small, but --delete_unwanted_images is not provided, so we simply ignore and remove from bucket.")
+            logger.warning(
+                f"Image too small, but --delete_unwanted_images is not provided, so we simply ignore and remove from bucket."
+            )
         self.remove_image(image_path, bucket)
 
     def handle_incorrect_bucket(self, image_path, bucket, actual_bucket):
@@ -138,7 +142,10 @@ def __iter__(self):
 
         bucket = self.buckets[self.current_bucket]
 
-        if len(self.buckets) > 1 and len(self.aspect_ratio_bucket_indices[bucket]) < self.batch_size:
+        if (
+            len(self.buckets) > 1
+            and len(self.aspect_ratio_bucket_indices[bucket]) < self.batch_size
+        ):
             if bucket not in self.exhausted_buckets:
                 self.move_to_exhausted()
             self.change_bucket()
@@ -168,7 +175,11 @@ def __iter__(self):
         if (len(available_images) < self.batch_size) and (len(self.buckets) == 1):
             # We have to check if we have enough 'seen' images, and bring them back.
             all_bucket_images = self.aspect_ratio_bucket_indices[bucket]
-            total = len(self.seen_images) + len(available_images) + len(all_bucket_images)
+            total = (
+                len(self.seen_images)
+                + len(available_images)
+                + len(all_bucket_images)
+            )
             if total < self.batch_size:
                 logger.warning(
                     f"Not enough unseen images ({len(available_images)}) in the bucket: {bucket}! Overly-repeating training images."
@@ -202,7 +213,10 @@ def __iter__(self):
            except:
                logger.warning(f"Image was bad or in-progress: {image_path}")
                continue
-            if image.width < self.minimum_image_size or image.height < self.minimum_image_size:
+            if (
+                image.width < self.minimum_image_size
+                or image.height < self.minimum_image_size
+            ):
                image.close()
                self.handle_small_image(image_path, bucket)
                continue
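For context on the sampler reformatted above: `aspect_ratio_bucket_indices` maps each aspect-ratio bucket to the images it holds, and `__iter__` draws batches from one bucket at a time, exhausting buckets that run below the batch size. An illustrative sketch of that structure (keys and paths are assumptions, not taken from the PR):

```python
# Each key is an aspect ratio; each value lists images of that shape.
aspect_ratio_bucket_indices = {
    "1.0": ["img_0001.png", "img_0002.png", "img_0003.png"],  # square
    "1.5": ["img_0004.png", "img_0005.png"],  # 3:2 landscape
}
buckets = list(aspect_ratio_bucket_indices.keys())
```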
13 changes: 8 additions & 5 deletions helpers/broken_images.py
@@ -1,22 +1,25 @@
 from PIL import Image
 import os
 
+
 def handle_broken_images(dir_path, delete=False):
     """Handle broken images in a given directory.
 
     Args:
         dir_path (str): The directory path to scan for images.
-        delete (bool, optional): If True, delete broken images.
+        delete (bool, optional): If True, delete broken images.
             Otherwise, just print their names. Defaults to False.
     """
     for filename in os.listdir(dir_path):
-        if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')):
+        if filename.lower().endswith(
+            (".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif")
+        ):
             try:
                 img_path = os.path.join(dir_path, filename)
                 with Image.open(img_path) as img:
                     img.verify()  # verify that it is, in fact an image
             except (IOError, SyntaxError) as e:
-                logging.info(f'Bad file: {img_path} - {e}')
+                logging.info(f"Bad file: {img_path} - {e}")
                 if delete:
                     os.remove(img_path)
-                    logging.info(f'Removed: {img_path}')
+                    logging.info(f"Removed: {img_path}")
23 changes: 16 additions & 7 deletions helpers/custom_schedule.py
@@ -1,7 +1,14 @@
 from torch.optim.lr_scheduler import LambdaLR
 import torch
 
+
 def get_polynomial_decay_schedule_with_warmup(
-    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
+    optimizer,
+    num_warmup_steps,
+    num_training_steps,
+    lr_end=1e-7,
+    power=1.0,
+    last_epoch=-1,
 ):
     """
     Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
@@ -33,7 +40,9 @@ def get_polynomial_decay_schedule_with_warmup(
 
     lr_init = optimizer.defaults["lr"]
     if not (float(lr_init) > float(lr_end)):
-        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
+        raise ValueError(
+            f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})"
+        )
 
     def lr_lambda(current_step: int):
         if current_step < num_warmup_steps:
@@ -56,20 +65,20 @@ def enforce_zero_terminal_snr(betas):
     alphas_bar = alphas.cumprod(0)
     alphas_bar_sqrt = alphas_bar.sqrt()
 
-    # Store old values. 
+    # Store old values.
     alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
     alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
     # Shift so last timestep is zero.
     alphas_bar_sqrt -= alphas_bar_sqrt_T
     # Scale so first timestep is back to old value.
-    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
-        alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
-    # Convert alphas_bar_sqrt to betas 
-    alphas_bar = alphas_bar_sqrt ** 2
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2
     alphas = alphas_bar[1:] / alphas_bar[:-1]
     alphas = torch.cat([alphas_bar[0:1], alphas])
     betas = 1 - alphas
     return betas
 
+
 def patch_scheduler_betas(scheduler):
     scheduler.betas = enforce_zero_terminal_snr(scheduler.betas)
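A quick numerical check of `enforce_zero_terminal_snr`, assuming a standard linear beta schedule (the schedule itself is an assumption, not from the PR): after rescaling, the cumulative alpha product at the final timestep is zero, which is exactly what zero terminal SNR means.

```python
import torch

betas = torch.linspace(1e-4, 0.02, 1000)  # assumed linear schedule
betas = enforce_zero_terminal_snr(betas)
alphas_bar = (1 - betas).cumprod(0)
print(alphas_bar[-1].item())  # ~0.0: no signal remains at the last timestep
```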