@@ -140,21 +140,25 @@ def __init__(self, args):
140140 self ._process_master_weight = None
141141 self ._process_optimizer_weight = None
142142 self ._lock = None
143- self ._shared_save_path = None
144143 self ._shared_save_model_flag = None
145144 self ._shared_save_master_weight_flag = None
146145 self ._shared_save_optimizer_flag = None
147146
148147 if "async_save" in self .args .unified_checkpoint_config :
149148 self ._lock = multiprocessing .Lock ()
150149 self ._shared_save_model_path = multiprocessing .Array ("c" , 100000 )
150+ self ._shared_save_model_signal_path = multiprocessing .Array ("c" , 100000 )
151151 self ._shared_save_master_weight_path = multiprocessing .Array ("c" , 100000 )
152+ self ._shared_save_master_weight_signal_path = multiprocessing .Array ("c" , 100000 )
152153 self ._shared_save_optimizer_path = multiprocessing .Array ("c" , 100000 )
154+ self ._shared_save_optimizer_signal_path = multiprocessing .Array ("c" , 100000 )
153155 self ._shared_save_model_flag = multiprocessing .Array ("i" , 1 )
154156 self ._shared_save_master_weight_flag = multiprocessing .Array ("i" , 1 )
155157 self ._shared_save_optimizer_flag = multiprocessing .Array ("i" , 1 )
156158
157- def _file_save_async_or_sync (self , state_dict , path , is_sync = True , state_dict_type = "model_weight" ):
159+ def _file_save_async_or_sync (
160+ self , state_dict , path , signal_path = None , is_sync = True , state_dict_type = "model_weight"
161+ ):
158162 if is_sync :
159163 for k in list (state_dict .keys ()):
160164 if isinstance (state_dict [k ], paddle .Tensor ):
@@ -169,6 +173,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
169173 meta_dict = self ._meta_dict_model
170174 shared_save_flag = self ._shared_save_model_flag
171175 shared_save_path = self ._shared_save_model_path
176+ shared_save_signal_path = self ._shared_save_model_signal_path
172177 if self ._process_model_weight is None :
173178 self ._process_model_weight = multiprocessing .Process (
174179 target = self ._save_file_async_in_process ,
@@ -177,6 +182,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
177182 self ._shm_model_weight .name ,
178183 self ._shared_save_model_flag ,
179184 self ._shared_save_model_path ,
185+ self ._shared_save_model_signal_path ,
180186 self ._lock ,
181187 state_dict_type ,
182188 self .global_rank ,
@@ -191,6 +197,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
191197 meta_dict = self ._meta_dict_master_weight
192198 shared_save_flag = self ._shared_save_master_weight_flag
193199 shared_save_path = self ._shared_save_master_weight_path
200+ shared_save_signal_path = self ._shared_save_master_weight_signal_path
194201 if self ._process_master_weight is None :
195202 self ._process_master_weight = multiprocessing .Process (
196203 target = self ._save_file_async_in_process ,
@@ -199,6 +206,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
199206 self ._shm_master_weight .name ,
200207 self ._shared_save_master_weight_flag ,
201208 self ._shared_save_master_weight_path ,
209+ self ._shared_save_master_weight_signal_path ,
202210 self ._lock ,
203211 "model_weight"
204212 if "skip_save_model_weight" in self .args .unified_checkpoint_config
@@ -215,6 +223,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
215223 meta_dict = self ._meta_dict_optim
216224 shared_save_flag = self ._shared_save_optimizer_flag
217225 shared_save_path = self ._shared_save_optimizer_path
226+ shared_save_signal_path = self ._shared_save_optimizer_signal_path
218227 if self ._process_optimizer_weight is None :
219228 self ._process_optimizer_weight = multiprocessing .Process (
220229 target = self ._save_file_async_in_process ,
@@ -223,6 +232,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
223232 self ._shm_optimizer_weight .name ,
224233 self ._shared_save_optimizer_flag ,
225234 self ._shared_save_optimizer_path ,
235+ self ._shared_save_optimizer_signal_path ,
226236 self ._lock ,
227237 state_dict_type ,
228238 self .global_rank ,
@@ -238,6 +248,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
238248 logger .info (f"Wait for the previous save process to finish saving { state_dict_type } " )
239249 # only save model weight or save master weight, we enter this loop.
240250 self ._reset_and_update (shared_save_path , path )
251+ self ._reset_and_update (shared_save_signal_path , signal_path )
241252 _traverse_copy_to_shm (state_dict , meta_dict , shm_state_dict .buf )
242253 with self ._lock :
243254 shared_save_flag [0 ] = 1
@@ -248,6 +259,7 @@ def _save_file_async_in_process(
248259 shm_name ,
249260 shared_save_flag ,
250261 shared_save_path ,
262+ shared_save_signal_path ,
251263 lock ,
252264 state_dict_type ,
253265 global_rank ,
@@ -261,11 +273,13 @@ def _save_file_async_in_process(
261273 continue
262274 if flag_value == 1 : # need to save
263275 path = shared_save_path [:].decode ("utf-8" ).rstrip ("\x00 " )
276+ signal_path = shared_save_signal_path [:].decode ("utf-8" ).rstrip ("\x00 " )
264277 logger .info (f"Start to async save { path } " )
265278 state_dict = _read_state_dict_from_shm (meta_dict , shm ) # numpy array
266279 safe_save_file (state_dict , path , {"format" : "np" })
267280 del state_dict
268- saved_signal_path = os .path .join (os .path .dirname (path ), f".{ state_dict_type } .done.{ global_rank } " )
281+ os .makedirs (signal_path , exist_ok = True )
282+ saved_signal_path = os .path .join (signal_path , f".{ state_dict_type } .done.{ global_rank } " )
269283 paddle .save (global_rank , saved_signal_path )
270284 with lock :
271285 shared_save_flag [0 ] = 0
@@ -280,7 +294,7 @@ def _reset_and_update(self, shared_array, new_value):
280294 encoded_value = new_value .encode ("utf-8" )
281295 shared_array [: len (encoded_value )] = encoded_value
282296
283- def save_unified_checkpoint (self , model , optimizer , output_dir ):
297+ def save_unified_checkpoint (self , model , optimizer , output_dir , signal_dir = None ):
284298 """save unified checkpoint
285299
286300 Args:
@@ -317,6 +331,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
317331
318332 save_directory = output_dir
319333 os .makedirs (save_directory , exist_ok = True )
334+ if signal_dir is not None :
335+ os .makedirs (signal_dir , exist_ok = True ) # only for async save
320336
321337 # save model weights
322338 if not skip_save_model_weight :
@@ -329,6 +345,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
329345 self ._file_save_async_or_sync (
330346 state_dict ,
331347 path = os .path .join (save_directory , shard_file ),
348+ signal_path = signal_dir ,
332349 is_sync = is_sync_save ,
333350 state_dict_type = "model_weight" ,
334351 )
@@ -397,7 +414,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
397414 if self .args .dataset_rank == 0 or self .args .use_expert_parallel :
398415 load_unified_checkpoint_locally (self .args , model , resume_from_checkpoint , safe_serialization = True )
399416
400- def save_non_merge_optimizer (self , model , optimizer , output_dir ):
417+ def save_non_merge_optimizer (self , model , optimizer , output_dir , signal_dir ):
401418 paddle .device .cuda .empty_cache ()
402419 optim_state_dict = nested_copy (optimizer .state_dict ())
403420 master_weights = None
@@ -456,12 +473,14 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir):
456473 self ._file_save_async_or_sync (
457474 optim_state_dict ,
458475 path = os .path .join (output_dir , optimizer_name ),
476+ signal_path = signal_dir ,
459477 is_sync = is_sync_save ,
460478 state_dict_type = "optimizer_weight" ,
461479 )
462480 self ._file_save_async_or_sync (
463481 master_weights ,
464482 path = os .path .join (output_dir , master_weights_name ),
483+ signal_path = signal_dir ,
465484 is_sync = is_sync_save ,
466485 state_dict_type = "master_weight" ,
467486 )
@@ -511,22 +530,23 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
511530
512531 return returned_optim_state_dict
513532
514- def save_unified_optimizer (self , model , optimizer , output_dir ):
533+ def save_unified_optimizer (self , model , optimizer , output_dir , signal_dir ):
515534 """save unified optimizer
516535
517536 Args:
518537 model (PretrainedModel): model used to get key mapping.
519538 optimizer (Optimizer): optimizer to save
520539 output_dir (str): Save directory.
540+ signal_dir (str): Asynchronous saving signal directory.
521541
522542 """
523543
524544 if "ignore_merge_optimizer" in self .args .unified_checkpoint_config :
525- self .save_non_merge_optimizer (model , optimizer , output_dir )
545+ self .save_non_merge_optimizer (model , optimizer , output_dir , signal_dir )
526546 return
527547
528548 if paddle .distributed .get_world_size () <= 1 :
529- self .save_single_card_optimizer (model , optimizer , output_dir )
549+ self .save_single_card_optimizer (model , optimizer , output_dir ) # no need to save signal
530550 return
531551
532552 # Split into naive optimizer params and master weights.
@@ -542,20 +562,24 @@ def save_unified_optimizer(self, model, optimizer, output_dir):
542562
543563 save_directory = output_dir
544564 os .makedirs (save_directory , exist_ok = True )
565+ if signal_dir is not None :
566+ os .makedirs (signal_dir , exist_ok = True )
545567
546568 is_sync_save = True
547569 if "async_save" in self .args .unified_checkpoint_config :
548570 is_sync_save = False
549571 self ._file_save_async_or_sync (
550572 optim_state_dict ,
551573 path = os .path .join (save_directory , shard_optim_file ),
574+ signal_path = signal_dir ,
552575 is_sync = is_sync_save ,
553576 state_dict_type = "optimizer_weight" ,
554577 )
555578 if master_weight_state_dict is not None :
556579 self ._file_save_async_or_sync (
557580 master_weight_state_dict ,
558581 path = os .path .join (save_directory , shard_master_weight_file ),
582+ signal_path = signal_dir ,
559583 is_sync = is_sync_save ,
560584 state_dict_type = "master_weight" ,
561585 )
0 commit comments