diff --git a/fine_tune.py b/fine_tune.py
index 52e84c43f..1059bbc5f 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -292,11 +292,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     if args.zero_terminal_snr:
         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+    train_util.init_trackers(accelerator, "finetuning", args)
 
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
diff --git a/library/train_util.py b/library/train_util.py
index cc9ac4555..7c7466dfc 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4721,6 +4721,35 @@ def __getitem__(self, idx):
 
 # endregion
 
+def init_trackers(accelerator, project_name, args):
+    if accelerator.is_main_process:
+        init_kwargs = {}
+        if args.log_tracker_config is not None:
+            init_kwargs = toml.load(args.log_tracker_config)
+        accelerator.init_trackers(
+            project_name=project_name if args.log_tracker_name is None else args.log_tracker_name,
+            config=clean_config_args(args),
+            init_kwargs=init_kwargs
+        )
+
+
+def clean_config_args(args):
+    result = {}
+    for k, v in vars(args).items():
+        if v is None:
+            result[k] = v
+        # tensorboard does not support lists
+        elif isinstance(v, list):
+            result[k] = f"{v}"
+        # tensorboard supports only scalar values; stringify other objects
+        elif not isinstance(v, (bool, int, float, str)):
+            result[k] = f"{v}"
+        else:
+            result[k] = v
+
+    return result
+
+
 # collate_fn用 epoch,stepはmultiprocessing.Value
 class collator_class:
     def __init__(self, epoch, step, dataset):
diff --git a/sdxl_train.py b/sdxl_train.py
index fd775624e..abfbc9551 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -455,11 +455,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     if args.zero_terminal_snr:
         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+    train_util.init_trackers(accelerator, "finetuning", args)
 
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 54abf697c..04e0af556 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -343,13 +343,7 @@ def train(args):
     if args.zero_terminal_snr:
         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
-        )
+    train_util.init_trackers(accelerator, "lllite_control_net_train", args)
 
     loss_recorder = train_util.LossRecorder()
     del train_dataset_group
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index f00f10eaa..8d3003f99 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -316,13 +316,7 @@ def train(args):
     if args.zero_terminal_snr:
         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
-        )
+    train_util.init_trackers(accelerator, "lllite_control_net_train", args)
 
     loss_recorder = train_util.LossRecorder()
     del train_dataset_group
diff --git a/train_controlnet.py b/train_controlnet.py
index bbd915cb3..2972333af 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -331,11 +331,8 @@ def train(args):
         num_train_timesteps=1000,
         clip_sample=False,
     )
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+
+    train_util.init_trackers(accelerator, "controlnet_train", args)
 
     loss_recorder = train_util.LossRecorder()
     del train_dataset_group
diff --git a/train_db.py b/train_db.py
index 7fbbc18ac..6b27a6095 100644
--- a/train_db.py
+++ b/train_db.py
@@ -268,11 +268,7 @@ def train(args):
     if args.zero_terminal_snr:
         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+    train_util.init_trackers(accelerator, "dreambooth", args)
 
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
diff --git a/train_network.py b/train_network.py
index d50916b74..f7c3257b5 100644
--- a/train_network.py
+++ b/train_network.py
@@ -702,13 +702,7 @@ def train(self, args):
         if args.zero_terminal_snr:
             custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-        if accelerator.is_main_process:
-            init_kwargs = {}
-            if args.log_tracker_config is not None:
-                init_kwargs = toml.load(args.log_tracker_config)
-            accelerator.init_trackers(
-                "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
-            )
+        train_util.init_trackers(accelerator, "network_train", args)
 
         loss_recorder = train_util.LossRecorder()
         del train_dataset_group
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 6b6e7f5a0..96cea56d7 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -502,13 +502,7 @@ def train(self, args):
         if args.zero_terminal_snr:
             custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-        if accelerator.is_main_process:
-            init_kwargs = {}
-            if args.log_tracker_config is not None:
-                init_kwargs = toml.load(args.log_tracker_config)
-            accelerator.init_trackers(
-                "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
-            )
+        train_util.init_trackers(accelerator, "textual_inversion", args)
 
         # function for saving/removing
         def save_model(ckpt_name, embs_list, steps, epoch_no, force_sync_upload=False):
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 8dd5c672f..25549bc83 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -395,11 +395,7 @@ def train(args):
     if args.zero_terminal_snr:
         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
 
-    if accelerator.is_main_process:
-        init_kwargs = {}
-        if args.log_tracker_config is not None:
-            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+    train_util.init_trackers(accelerator, "textual_inversion", args)
 
     # function for saving/removing
     def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
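
Reviewer note: a minimal sketch (not part of the patch) of what the new clean_config_args helper in library/train_util.py does before the result is passed as config= to accelerator.init_trackers. The Namespace fields below are hypothetical, chosen only to exercise each branch.

import argparse

from library import train_util

# Hypothetical argument namespace; the real scripts build this with argparse.
args = argparse.Namespace(
    learning_rate=1e-4,            # scalar: kept as-is, numeric type preserved
    max_token_length=None,         # None: passed through unchanged
    optimizer_args=["betas=0.9"],  # list: stringified, tensorboard rejects lists
)

config = train_util.clean_config_args(args)
assert config["learning_rate"] == 1e-4
assert config["optimizer_args"] == "['betas=0.9']"  # the list became a string
assert config["max_token_length"] is None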