Commit e6d74f7

[AutoParallel] Refine auto_trainer save load (#8767)
1 parent 5fd6dd2 commit e6d74f7

2 files changed

Lines changed: 79 additions & 59 deletions

paddlenlp/trainer/auto_trainer.py

Lines changed: 45 additions & 29 deletions
@@ -122,12 +122,10 @@ def _wrap_for_auto(self, model, train_dataloader):
         if self.args.to_static:
             unified_strategy = dist.Strategy()
             unified_strategy._from_legacy_strategy(self.args.strategy)
-            return (
-                dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=unified_strategy),
-                dist_loader,
-            )
-        else:
-            return model, dist_loader
+            model = dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=unified_strategy)
+
+        self.model_wrapped = model
+        return model, dist_loader

     def _wrap_amp_model(self, args, model):
         logger.info("Using half precision")
@@ -216,7 +214,6 @@ def _inner_training_loop(
         epochs_trained = self.state.global_step // num_update_steps_per_epoch
         if not args.ignore_data_skip:
             steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-            steps_trained_in_current_epoch *= args.gradient_accumulation_steps
         else:
             steps_trained_in_current_epoch = 0

@@ -269,6 +266,9 @@ def _inner_training_loop(
         model, dist_loader = self._wrap_for_auto(model, train_dataloader)
         train_dataloader = dist_loader()

+        if resume_from_checkpoint is not None:
+            self._load_from_checkpoint(resume_from_checkpoint)
+
         self.timers and self.timers("read-data").start()

         for epoch in range(epochs_trained, num_train_epochs):
@@ -542,14 +542,26 @@ def _save_checkpoint(self, model, metrics=None):
         logger.info(f"Saving checkpoinit files into {output_dir}")

         if self.args.should_save_model_state:
-
-            optim_state_dict = self.optimizer.state_dict()
-            optim_state_dict.pop("LR_Scheduler", None)
-
-            state_dict = {
-                MODEL_NAME: self.model.state_dict(),
-                OPTIMIZER_NAME: optim_state_dict,
-            }
+            if self.args.to_static:
+                state_dict = model.state_dict()
+            else:
+                optim_state_dict = self.optimizer.state_dict()
+                optim_state_dict.pop("LR_Scheduler", None)
+                opt_state_keys = ["_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0"]
+                for p_name, p in model.state_dict().items():
+                    if paddle.distributed.get_rank() not in p.process_mesh.process_ids:
+                        var_name = p.name
+                        for key in opt_state_keys:
+                            if (
+                                var_name + key in optim_state_dict
+                                and not optim_state_dict[var_name + key].is_dist()
+                            ):
+                                optim_state_dict.pop(var_name + key)
+
+                state_dict = {
+                    MODEL_NAME: model.state_dict(),
+                    OPTIMIZER_NAME: optim_state_dict,
+                }

             self._save_ckpt_func(state_dict, os.path.join(output_dir, DIST_CKPT_PATH))
             logger.info(f"Model weights and optimizer states saved in {output_dir}/{DIST_CKPT_PATH}")
@@ -584,13 +596,9 @@ def _save_checkpoint(self, model, metrics=None):
         rng_states = {
             "python": random.getstate(),
             "numpy": np.random.get_state(),
-            "cuda": [k.current_seed() for k in paddle.get_rng_state()],
-            "cpu": paddle.framework.core.default_cpu_generator().get_state().current_seed(),
+            "cuda": paddle.get_rng_state(),
+            "cpu": paddle.framework.core.default_cpu_generator().get_state(),
         }
-        # if self.args.use_hybrid_parallel:
-        #     rng_states[
-        #         "hybrid_parallel_rng_state_tracker"
-        #     ] = fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()

         if self.args.world_size > 1:
             rng_states_list = []
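
The RNG entries are now stored as full generator states rather than seeds, so resuming can restore the exact generator position. A small sketch of the capture/restore pairing, assuming the usual paddle.set_rng_state and generator set_state counterparts (the restore side is not part of this diff):

import random
import numpy as np
import paddle

# Capture, in the same structure as the checkpoint above.
rng_states = {
    "python": random.getstate(),
    "numpy": np.random.get_state(),
    "cuda": paddle.get_rng_state(),  # one generator state per visible device
    "cpu": paddle.framework.core.default_cpu_generator().get_state(),
}

# Restore on resume (assumed counterparts of the capture calls).
random.setstate(rng_states["python"])
np.random.set_state(rng_states["numpy"])
paddle.set_rng_state(rng_states["cuda"])
paddle.framework.core.default_cpu_generator().set_state(rng_states["cpu"])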
@@ -660,15 +668,23 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
             if not os.path.isdir(ckpt_path):
                 raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

-            optim_state_dict = self.optimizer.state_dict()
-            optim_state_dict.pop("LR_Scheduler", None)
-
-            state_dict = {
-                MODEL_NAME: self.model.state_dict(),
-                OPTIMIZER_NAME: optim_state_dict,
-            }
+            if self.args.to_static:
+                state_dict = self.model_wrapped.state_dict()
+            else:
+                model_state_dict = self.model_wrapped.state_dict()
+                optim_state_dict = self.optimizer.state_dict()
+                optim_state_dict.pop("LR_Scheduler", None)
+                if len(optim_state_dict) == 0:
+                    self.optimizer._create_accumulators(
+                        paddle.base.framework.default_main_program().global_block(), self.optimizer._parameter_list
+                    )
+                    optim_state_dict = self.optimizer.state_dict()
+                    optim_state_dict.pop("LR_Scheduler", None)

-            print("state_dict :", state_dict)
+                state_dict = {
+                    MODEL_NAME: model_state_dict,
+                    OPTIMIZER_NAME: optim_state_dict,
+                }

             self._load_ckpt_func(state_dict, ckpt_path)

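The _create_accumulators call above exists because a freshly built optimizer has an empty state dict until its accumulators are created, and the distributed load presumably writes into the tensors of the state dict that is passed in, so the target keys must already exist. A small dygraph illustration of when Adam-style accumulators appear (illustrative only; the trainer uses the private _create_accumulators hook instead of running a dummy step):

import paddle

model = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.AdamW(parameters=model.parameters())

# Before any step the optimizer has no accumulators, so its state dict is empty.
print(len(opt.state_dict()))  # 0

# One forward/backward/step creates the moment and beta-pow accumulators.
loss = model(paddle.rand([2, 4])).mean()
loss.backward()
opt.step()
print(sorted(opt.state_dict().keys())[:4])  # e.g. linear_0.w_0_beta1_pow_acc_0, ...
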
paddlenlp/trainer/trainer.py

Lines changed: 34 additions & 30 deletions
@@ -748,11 +748,6 @@ def train(
             os.makedirs(resume_from_checkpoint, exist_ok=True)
             logger.info(f"Reset resume_from_checkpoint to temp directory : {resume_from_checkpoint}")

-        # memory metrics - must set up as early as possible
-        self._memory_tracker.start()
-        if not self.args.should_load_sharding_stage1_model:
-            self._load_from_checkpoint(resume_from_checkpoint)
-
         train_dataloader = self.get_train_dataloader()

         total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.dataset_world_size
@@ -803,34 +798,43 @@ def train(

         self.state = TrainerState()

-        if self.args.should_load_sharding_stage1_model:
-            model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)
-
-        elif self.args.should_save_sharding_stage1_model:
-            # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
-            # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
-            model = self._wrap_model(self.model_wrapped)
-            if self.sharding_io is not None:
-                assert delay_optimizer_creation is False, "delay_optimizer_creation should be False"
-                # the self.optimizer should be wrapped and it is done in _wrap_model
-                self.sharding_io.set_optimizer(self.optimizer)
-            # for the rest of this function `model` is the outside model, whether it was wrapped or not
-            if model is not self.model:
-                self.model_wrapped = model
-            if delay_optimizer_creation:
-                self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-            self._load_optimizer_and_scheduler(resume_from_checkpoint)
+        # memory metrics - must set up as early as possible
+        self._memory_tracker.start()
+
+        if not self.args.enable_auto_parallel:
+            if not self.args.should_load_sharding_stage1_model:
+                self._load_from_checkpoint(resume_from_checkpoint)
+
+            if self.args.should_load_sharding_stage1_model:
+                model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)
+
+            elif self.args.should_save_sharding_stage1_model:
+                # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
+                # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
+                model = self._wrap_model(self.model_wrapped)
+                if self.sharding_io is not None:
+                    assert delay_optimizer_creation is False, "delay_optimizer_creation should be False"
+                    # the self.optimizer should be wrapped and it is done in _wrap_model
+                    self.sharding_io.set_optimizer(self.optimizer)
+                # for the rest of this function `model` is the outside model, whether it was wrapped or not
+                if model is not self.model:
+                    self.model_wrapped = model
+                if delay_optimizer_creation:
+                    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+                self._load_optimizer_and_scheduler(resume_from_checkpoint)
+            else:
+                model = self._wrap_model(self.model_wrapped)
+                # for the rest of this function `model` is the outside model, whether it was wrapped or not
+                if model is not self.model:
+                    self.model_wrapped = model
+                if delay_optimizer_creation:
+                    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+                self._load_optimizer_and_scheduler(resume_from_checkpoint)
         else:
-            model = self._wrap_model(self.model_wrapped)
-            # for the rest of this function `model` is the outside model, whether it was wrapped or not
-            if model is not self.model:
-                self.model_wrapped = model
-            if delay_optimizer_creation:
-                self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-            self._load_optimizer_and_scheduler(resume_from_checkpoint)
+            model = self.model_wrapped
+            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

         logger.info(f"{self.runtime_timer.log()}")
-
         logger.info("***** Running training *****")
         logger.info(f" Num examples = {num_examples:,}")
         logger.info(f" Num Epochs = {num_train_epochs}")
