@@ -122,12 +122,10 @@ def _wrap_for_auto(self, model, train_dataloader):
         if self.args.to_static:
             unified_strategy = dist.Strategy()
             unified_strategy._from_legacy_strategy(self.args.strategy)
-            return (
-                dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=unified_strategy),
-                dist_loader,
-            )
-        else:
-            return model, dist_loader
+            model = dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=unified_strategy)
+
+        self.model_wrapped = model
+        return model, dist_loader
 
     def _wrap_amp_model(self, args, model):
         logger.info("Using half precision")
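Net effect of this hunk: `_wrap_for_auto` no longer has a separate static-graph exit. Both modes fall through to one return, and the (possibly `dist.to_static`-converted) model is recorded on `self.model_wrapped`, which the checkpoint hunks below depend on. A hedged sketch of the resulting contract (the `trainer` usage is illustrative, not code from this PR):

```python
# Uniform contract after the change: same return shape in both modes,
# and trainer.model_wrapped always points at the object whose
# state_dict() the checkpoint code will read.
model, dist_loader = trainer._wrap_for_auto(model, train_dataloader)
assert model is trainer.model_wrapped

train_dataloader = dist_loader()  # the loader is callable, as in the training loop below
for batch in train_dataloader:
    ...
```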
@@ -216,7 +214,6 @@ def _inner_training_loop(
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
             if not args.ignore_data_skip:
                 steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
             else:
                 steps_trained_in_current_epoch = 0
 
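The deleted multiplication changed the unit of the skip count. `self.state.global_step` counts optimizer updates, so the modulo already yields a value in update steps; scaling by `gradient_accumulation_steps` would convert it to micro-batches, which (reading this diff) is not how the auto-parallel loader is consumed. A worked example with made-up numbers:

```python
# Made-up values to show the resume arithmetic in units of optimizer updates.
global_step = 25                  # optimizer updates completed before the restart
num_update_steps_per_epoch = 10
gradient_accumulation_steps = 4

epochs_trained = global_step // num_update_steps_per_epoch                  # 2
steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch   # 5

# The removed line would have turned 5 into 20 (micro-batches), making the
# loop skip four times too many batches when each yielded batch already
# corresponds to one optimizer update.
assert (epochs_trained, steps_trained_in_current_epoch) == (2, 5)
```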
@@ -269,6 +266,9 @@ def _inner_training_loop(
         model, dist_loader = self._wrap_for_auto(model, train_dataloader)
         train_dataloader = dist_loader()
 
+        if resume_from_checkpoint is not None:
+            self._load_from_checkpoint(resume_from_checkpoint)
+
         self.timers and self.timers("read-data").start()
 
         for epoch in range(epochs_trained, num_train_epochs):
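Resuming now happens only after `_wrap_for_auto` has run, which is what makes the `self.model_wrapped` bookkeeping from the first hunk load-bearing; in sketch form:

```python
# Sketch of the required ordering (simplified from the loop above):
model, dist_loader = self._wrap_for_auto(model, train_dataloader)  # sets self.model_wrapped
train_dataloader = dist_loader()

if resume_from_checkpoint is not None:
    # _load_from_checkpoint (last hunk) reads self.model_wrapped.state_dict(),
    # so calling it any earlier would see the unwrapped model.
    self._load_from_checkpoint(resume_from_checkpoint)
```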
@@ -542,14 +542,26 @@ def _save_checkpoint(self, model, metrics=None):
         logger.info(f"Saving checkpoint files into {output_dir}")
 
         if self.args.should_save_model_state:
-
-            optim_state_dict = self.optimizer.state_dict()
-            optim_state_dict.pop("LR_Scheduler", None)
-
-            state_dict = {
-                MODEL_NAME: self.model.state_dict(),
-                OPTIMIZER_NAME: optim_state_dict,
-            }
+            if self.args.to_static:
+                state_dict = model.state_dict()
+            else:
+                optim_state_dict = self.optimizer.state_dict()
+                optim_state_dict.pop("LR_Scheduler", None)
+                opt_state_keys = ["_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0"]
+                for p_name, p in model.state_dict().items():
+                    if paddle.distributed.get_rank() not in p.process_mesh.process_ids:
+                        var_name = p.name
+                        for key in opt_state_keys:
+                            if (
+                                var_name + key in optim_state_dict
+                                and not optim_state_dict[var_name + key].is_dist()
+                            ):
+                                optim_state_dict.pop(var_name + key)
+
+                state_dict = {
+                    MODEL_NAME: model.state_dict(),
+                    OPTIMIZER_NAME: optim_state_dict,
+                }
 
             self._save_ckpt_func(state_dict, os.path.join(output_dir, DIST_CKPT_PATH))
             logger.info(f"Model weights and optimizer states saved in {output_dir}/{DIST_CKPT_PATH}")
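In the dynamic branch, the filter relies on Paddle's Adam accumulator naming: each optimizer-state key is the parameter's `name` plus a suffix such as `_moment1_0`. For parameters whose process mesh does not include the current rank, any such entry that is a plain local tensor (`is_dist()` is false) is popped, so the distributed saver only receives state this rank actually owns. A small sketch of the key construction (the parameter name is hypothetical):

```python
# How the pop targets are derived, per the loop above.
opt_state_keys = ["_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0"]

param_name = "linear_0.w_0"  # hypothetical Paddle parameter name (p.name)
candidate_keys = [param_name + key for key in opt_state_keys]
# -> ['linear_0.w_0_moment1_0', 'linear_0.w_0_moment2_0',
#     'linear_0.w_0_beta1_pow_acc_0', 'linear_0.w_0_beta2_pow_acc_0']

# Each candidate is removed from optim_state_dict only when BOTH hold:
#   1. the owning parameter is not placed on this rank
#      (paddle.distributed.get_rank() not in p.process_mesh.process_ids), and
#   2. the stored accumulator is not a dist tensor, i.e. a purely local
#      duplicate the distributed saver should not receive.
```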
@@ -584,13 +596,9 @@ def _save_checkpoint(self, model, metrics=None):
             rng_states = {
                 "python": random.getstate(),
                 "numpy": np.random.get_state(),
-                "cuda": [k.current_seed() for k in paddle.get_rng_state()],
-                "cpu": paddle.framework.core.default_cpu_generator().get_state().current_seed(),
+                "cuda": paddle.get_rng_state(),
+                "cpu": paddle.framework.core.default_cpu_generator().get_state(),
             }
-            # if self.args.use_hybrid_parallel:
-            #     rng_states[
-            #         "hybrid_parallel_rng_state_tracker"
-            #     ] = fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()
 
             if self.args.world_size > 1:
                 rng_states_list = []
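Storing the full generator state objects, rather than `current_seed()` values, is what makes the checkpoint exactly resumable: a seed only identifies the start of a random stream, while the state also encodes the current position in it. A hedged restore-side sketch, assuming `paddle.set_rng_state` accepts the list returned by `paddle.get_rng_state` and that the CPU generator exposes `set_state`:

```python
import random

import numpy as np
import paddle

def restore_rng(rng_states):
    # Inverse of the save path above; each setter is the standard
    # counterpart of the getter used when the checkpoint was written.
    random.setstate(rng_states["python"])
    np.random.set_state(rng_states["numpy"])
    paddle.set_rng_state(rng_states["cuda"])  # one generator state per device
    paddle.framework.core.default_cpu_generator().set_state(rng_states["cpu"])
```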
@@ -660,15 +668,23 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
         if not os.path.isdir(ckpt_path):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
 
-        optim_state_dict = self.optimizer.state_dict()
-        optim_state_dict.pop("LR_Scheduler", None)
-
-        state_dict = {
-            MODEL_NAME: self.model.state_dict(),
-            OPTIMIZER_NAME: optim_state_dict,
-        }
+        if self.args.to_static:
+            state_dict = self.model_wrapped.state_dict()
+        else:
+            model_state_dict = self.model_wrapped.state_dict()
+            optim_state_dict = self.optimizer.state_dict()
+            optim_state_dict.pop("LR_Scheduler", None)
+            if len(optim_state_dict) == 0:
+                self.optimizer._create_accumulators(
+                    paddle.base.framework.default_main_program().global_block(), self.optimizer._parameter_list
+                )
+                optim_state_dict = self.optimizer.state_dict()
+                optim_state_dict.pop("LR_Scheduler", None)
 
-        print("state_dict :", state_dict)
+            state_dict = {
+                MODEL_NAME: model_state_dict,
+                OPTIMIZER_NAME: optim_state_dict,
+            }
 
         self._load_ckpt_func(state_dict, ckpt_path)
 
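The `len(optim_state_dict) == 0` branch handles resuming before any training step: in dynamic mode the optimizer creates its accumulators lazily on the first `step()`, so a fresh optimizer's `state_dict()` has nothing for `_load_ckpt_func` to fill. The diff forces creation through the private `_create_accumulators` API. A toy reproduction of the empty-state problem (a sketch; the behavior is assumed from the diff and Paddle's lazy-accumulator semantics):

```python
import paddle

layer = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.AdamW(parameters=layer.parameters())

state = opt.state_dict()
state.pop("LR_Scheduler", None)
print(len(state))  # 0: no accumulators exist before the first opt.step()

# Pre-create moment/beta accumulators, mirroring the diff, so that
# state_dict() returns real target tensors for the checkpoint loader.
opt._create_accumulators(
    paddle.base.framework.default_main_program().global_block(), opt._parameter_list
)
state = opt.state_dict()
state.pop("LR_Scheduler", None)
print(len(state))  # now 4 entries per parameter (two moments, two beta pows)
```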