6 changes: 1 addition & 5 deletions fine_tune.py
@@ -292,11 +292,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     if args.zero_terminal_snr:
         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+    train_util.init_trackers(accelerator, "finetuning", args)
 
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
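
The removed block (and the new helper that replaces it) reads --log_tracker_config as a TOML file whose top-level tables map tracker names to their init kwargs; toml.load turns it into the dict handed to accelerator.init_trackers. A minimal sketch of that mapping, using illustrative wandb keys that this PR does not define:

import toml

# Hypothetical log_tracker_config contents (keys are illustrative, not from this PR):
tracker_config = """
[wandb]
entity = "my-team"
name = "finetune-run-1"
"""

# toml.loads parses it into the per-tracker kwargs dict forwarded as init_kwargs
init_kwargs = toml.loads(tracker_config)
assert init_kwargs == {"wandb": {"entity": "my-team", "name": "finetune-run-1"}}
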
29 changes: 29 additions & 0 deletions library/train_util.py
@@ -4721,6 +4721,35 @@ def __getitem__(self, idx):
 # endregion
 
 
+def init_trackers(accelerator, project_name, args):
+    if accelerator.is_main_process:
+        init_kwargs = {}
+        if args.log_tracker_config is not None:
+            init_kwargs = toml.load(args.log_tracker_config)
+        accelerator.init_trackers(
+            project_name=project_name if args.log_tracker_name is None else args.log_tracker_name,
+            config=clean_config_args(args),
+            init_kwargs=init_kwargs
+        )
+
+
+def clean_config_args(args):
+    result = {}
+    for k, v in vars(args).items():
+        if v is None:
+            result[k] = v
+        # tensorboard does not support lists
+        elif isinstance(v, list):
+            result[k] = f"{v}"
+        # tensorboard does not support arbitrary objects, so stringify anything that is not a plain scalar
+        elif not isinstance(v, (bool, int, float, str)):
+            result[k] = f"{v}"
+        else:
+            result[k] = v
+
+    return result
+
+
 # for collate_fn; epoch and step are multiprocessing.Value
 class collator_class:
     def __init__(self, epoch, step, dataset):
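
A minimal, self-contained sketch of how the two new helpers behave; the Namespace fields below are illustrative stand-ins for the real command-line options, and only log_tracker_name and log_tracker_config are read by init_trackers itself:

import argparse

from library import train_util

args = argparse.Namespace(
    log_tracker_name=None,    # None -> init_trackers falls back to the per-script project name
    log_tracker_config=None,  # no TOML file -> empty init_kwargs
    learning_rate=1e-5,       # scalars pass through clean_config_args unchanged
    resolution=[512, 512],    # lists are stringified for tensorboard
)

# clean_config_args flattens the namespace into tracker-friendly values:
# {"log_tracker_name": None, "log_tracker_config": None,
#  "learning_rate": 1e-05, "resolution": "[512, 512]"}
print(train_util.clean_config_args(args))

# In a training script, after accelerator = Accelerator(log_with=...):
# train_util.init_trackers(accelerator, "finetuning", args)
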
6 changes: 1 addition & 5 deletions sdxl_train.py
@@ -455,11 +455,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     if args.zero_terminal_snr:
         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+    train_util.init_trackers(accelerator, "finetuning", args)
 
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
8 changes: 1 addition & 7 deletions sdxl_train_control_net_lllite.py
@@ -343,13 +343,7 @@ def train(args):
     if args.zero_terminal_snr:
         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
-        )
+    train_util.init_trackers(accelerator, "lllite_control_net_train", args)
 
     loss_recorder = train_util.LossRecorder()
     del train_dataset_group
8 changes: 1 addition & 7 deletions sdxl_train_control_net_lllite_old.py
@@ -316,13 +316,7 @@ def train(args):
     if args.zero_terminal_snr:
         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
-        )
+    train_util.init_trackers(accelerator, "lllite_control_net_train", args)
 
     loss_recorder = train_util.LossRecorder()
     del train_dataset_group
7 changes: 2 additions & 5 deletions train_controlnet.py
@@ -331,11 +331,8 @@ def train(args):
         num_train_timesteps=1000,
         clip_sample=False,
     )
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+
+    train_util.init_trackers(accelerator, "controlnet_train", args)
 
     loss_recorder = train_util.LossRecorder()
     del train_dataset_group
6 changes: 1 addition & 5 deletions train_db.py
@@ -268,11 +268,7 @@ def train(args):
     if args.zero_terminal_snr:
         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+    train_util.init_trackers(accelerator, "dreambooth", args)
 
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
8 changes: 1 addition & 7 deletions train_network.py
@@ -702,13 +702,7 @@ def train(self, args):
         if args.zero_terminal_snr:
            custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-        if accelerator.is_main_process:
-            init_kwargs = {}
-            if args.log_tracker_config is not None:
-                init_kwargs = toml.load(args.log_tracker_config)
-            accelerator.init_trackers(
-                "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
-            )
+        train_util.init_trackers(accelerator, "network_train", args)
 
         loss_recorder = train_util.LossRecorder()
         del train_dataset_group
8 changes: 1 addition & 7 deletions train_textual_inversion.py
@@ -502,13 +502,7 @@ def train(self, args):
         if args.zero_terminal_snr:
             custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-        if accelerator.is_main_process:
-            init_kwargs = {}
-            if args.log_tracker_config is not None:
-                init_kwargs = toml.load(args.log_tracker_config)
-            accelerator.init_trackers(
-                "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
-            )
+        train_util.init_trackers(accelerator, "textual_inversion", args)
 
         # function for saving/removing
         def save_model(ckpt_name, embs_list, steps, epoch_no, force_sync_upload=False):
6 changes: 1 addition & 5 deletions train_textual_inversion_XTI.py
@@ -395,11 +395,7 @@ def train(args):
     if args.zero_terminal_snr:
         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+    train_util.init_trackers(accelerator, "textual_inversion", args)
 
     # function for saving/removing
     def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):