Skip to content

Commit bf6cd4b

Browse files
authored
Merge pull request #1168 from gesen2egee/save_state_on_train_end
Save state on train end
2 parents (3b0db0f + d282c45), commit bf6cd4b

9 files changed

Lines changed: 13 additions & 8 deletions

fine_tune.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -457,7 +457,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
457457

458458
accelerator.end_training()
459459

460-
if args.save_state and is_main_process:
460+
if is_main_process and (args.save_state or args.save_state_on_train_end):
461461
train_util.save_state_on_train_end(args, accelerator)
462462

463463
del accelerator # この後メモリを使うのでこれは消す

library/train_util.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2938,6 +2938,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
29382938
action="store_true",
29392939
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する",
29402940
)
2941+
parser.add_argument(
2942+
"--save_state_on_train_end",
2943+
action="store_true",
2944+
help="save training state additionally (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを追加で保存する",
2945+
)
29412946
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
29422947

29432948
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")

sdxl_train.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -712,7 +712,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
712712

713713
accelerator.end_training()
714714

715-
if args.save_state: # and is_main_process:
715+
if args.save_state or args.save_state_on_train_end:
716716
train_util.save_state_on_train_end(args, accelerator)
717717

718718
del accelerator # この後メモリを使うのでこれは消す

sdxl_train_control_net_lllite.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -549,7 +549,7 @@ def remove_model(old_ckpt_name):
549549

550550
accelerator.end_training()
551551

552-
if is_main_process and args.save_state:
552+
if is_main_process and (args.save_state or args.save_state_on_train_end):
553553
train_util.save_state_on_train_end(args, accelerator)
554554

555555
if is_main_process:

train_controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -565,7 +565,7 @@ def remove_model(old_ckpt_name):
565565

566566
accelerator.end_training()
567567

568-
if is_main_process and args.save_state:
568+
if is_main_process and (args.save_state or args.save_state_on_train_end):
569569
train_util.save_state_on_train_end(args, accelerator)
570570

571571
# del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく

train_db.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -444,7 +444,7 @@ def train(args):
444444

445445
accelerator.end_training()
446446

447-
if args.save_state and is_main_process:
447+
if is_main_process and (args.save_state or args.save_state_on_train_end):
448448
train_util.save_state_on_train_end(args, accelerator)
449449

450450
del accelerator # この後メモリを使うのでこれは消す

train_network.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -940,7 +940,7 @@ def remove_model(old_ckpt_name):
940940

941941
accelerator.end_training()
942942

943-
if is_main_process and args.save_state:
943+
if is_main_process and (args.save_state or args.save_state_on_train_end):
944944
train_util.save_state_on_train_end(args, accelerator)
945945

946946
if is_main_process:

train_textual_inversion.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -732,7 +732,7 @@ def remove_model(old_ckpt_name):
732732

733733
accelerator.end_training()
734734

735-
if args.save_state and is_main_process:
735+
if is_main_process and (args.save_state or args.save_state_on_train_end):
736736
train_util.save_state_on_train_end(args, accelerator)
737737

738738
if is_main_process:

train_textual_inversion_XTI.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -586,7 +586,7 @@ def remove_model(old_ckpt_name):
586586

587587
accelerator.end_training()
588588

589-
if args.save_state and is_main_process:
589+
if is_main_process and (args.save_state or args.save_state_on_train_end):
590590
train_util.save_state_on_train_end(args, accelerator)
591591

592592
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()

0 commit comments

Comments
 (0)