diff --git a/fine_tune.py b/fine_tune.py
index 875a91951..46f128287 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -457,7 +457,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     accelerator.end_training()
 
-    if args.save_state and is_main_process:
+    if is_main_process and (args.save_state or args.save_state_on_train_end):
         train_util.save_state_on_train_end(args, accelerator)
 
     del accelerator  # この後メモリを使うのでこれは消す
diff --git a/library/train_util.py b/library/train_util.py
index d2b69edb5..b3ca15f55 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2890,6 +2890,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         action="store_true",
         help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する",
     )
+    parser.add_argument(
+        "--save_state_on_train_end",
+        action="store_true",
+        help="save training state additionally (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを追加で保存する",
+    )
     parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
 
     parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
diff --git a/sdxl_train.py b/sdxl_train.py
index e0df263d6..107bb9451 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -712,7 +712,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     accelerator.end_training()
 
-    if args.save_state:  # and is_main_process:
+    if args.save_state or args.save_state_on_train_end:
         train_util.save_state_on_train_end(args, accelerator)
 
     del accelerator  # この後メモリを使うのでこれは消す
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 1e5f92349..e99b4e35c 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -549,7 +549,7 @@ def remove_model(old_ckpt_name):
 
     accelerator.end_training()
 
-    if is_main_process and args.save_state:
+    if is_main_process and (args.save_state or args.save_state_on_train_end):
         train_util.save_state_on_train_end(args, accelerator)
 
     if is_main_process:
diff --git a/train_controlnet.py b/train_controlnet.py
index dc73a91c8..e44f08853 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -565,7 +565,7 @@ def remove_model(old_ckpt_name):
 
     accelerator.end_training()
 
-    if is_main_process and args.save_state:
+    if is_main_process and (args.save_state or args.save_state_on_train_end):
         train_util.save_state_on_train_end(args, accelerator)
 
     # del accelerator  # この後メモリを使うのでこれは消す→printで使うので消さずにおく
diff --git a/train_db.py b/train_db.py
index 8d36097a5..41a9a7b99 100644
--- a/train_db.py
+++ b/train_db.py
@@ -444,7 +444,7 @@ def train(args):
 
     accelerator.end_training()
 
-    if args.save_state and is_main_process:
+    if is_main_process and (args.save_state or args.save_state_on_train_end):
         train_util.save_state_on_train_end(args, accelerator)
 
     del accelerator  # この後メモリを使うのでこれは消す
diff --git a/train_network.py b/train_network.py
index e0fa69458..3db583f1d 100644
--- a/train_network.py
+++ b/train_network.py
@@ -935,7 +935,7 @@ def remove_model(old_ckpt_name):
 
         accelerator.end_training()
 
-        if is_main_process and args.save_state:
+        if is_main_process and (args.save_state or args.save_state_on_train_end):
             train_util.save_state_on_train_end(args, accelerator)
 
         if is_main_process:
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index df1d8485a..0266bc143 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -732,7 +732,7 @@ def remove_model(old_ckpt_name):
 
         accelerator.end_training()
 
-        if args.save_state and is_main_process:
+        if is_main_process and (args.save_state or args.save_state_on_train_end):
             train_util.save_state_on_train_end(args, accelerator)
 
         if is_main_process:
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 695fad2a8..ad7c267eb 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -586,7 +586,7 @@ def remove_model(old_ckpt_name):
 
     accelerator.end_training()
 
-    if args.save_state and is_main_process:
+    if is_main_process and (args.save_state or args.save_state_on_train_end):
         train_util.save_state_on_train_end(args, accelerator)
 
     updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
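For reference, here is a minimal, self-contained sketch (not part of the patch) of how the new flag behaves. The argparse definitions mirror the ones added to `library/train_util.py`; `is_main_process` and the final `print` are stand-ins for the accelerator plumbing in the training scripts.

```python
# Sketch only: --save_state_on_train_end requests the end-of-training state save
# without also enabling the intermediate state saves that --save_state controls
# elsewhere in these scripts.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save_state", action="store_true")
parser.add_argument("--save_state_on_train_end", action="store_true")
args = parser.parse_args(["--save_state_on_train_end"])

is_main_process = True  # stand-in for accelerator.is_main_process

# Guard as introduced by this patch: save the final state if either flag is set.
if is_main_process and (args.save_state or args.save_state_on_train_end):
    print("train_util.save_state_on_train_end(args, accelerator) would run here")
```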