3232 check_outdir_not_in_ckptdir ,
3333 copy_files ,
3434 get_adapter_checkpoint_path ,
35+ get_all_checkpoints_in_dir ,
3536 get_model_checkpoint_path ,
3637 get_recipe_checkpoint_path ,
3738 ModelType ,
39+ prune_surplus_checkpoints ,
3840 RECIPE_STATE_DIRNAME ,
3941 REPO_ID_FNAME ,
4042 safe_torch_load ,
@@ -399,6 +401,7 @@ class FullModelHFCheckpointer(_CheckpointerInterface):
399401 Default is True.
400402 should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
401403 the recipe state from a previous run. Default is False
404+ keep_last_n_checkpoints (Optional[int]): Maximum number of checkpoints to keep in the output directory; surplus checkpoints are pruned after each save. If None, all checkpoints are kept. Default is None.
402405 """
403406
404407 def __init__ (
@@ -412,6 +415,8 @@ def __init__(
412415 resume_from_checkpoint : bool = False ,
413416 safe_serialization : bool = True ,
414417 should_load_recipe_state : bool = False ,
418+ * ,
419+ keep_last_n_checkpoints : Optional [int ] = None ,
415420 ) -> None :
416421
417422 self ._should_load_recipe_state = should_load_recipe_state
@@ -420,6 +425,7 @@ def __init__(
420425 logger .warning (
421426 "*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead"
422427 )
428+ self ._keep_last_n_checkpoints = keep_last_n_checkpoints
423429
424430 self ._safe_serialization = safe_serialization
425431 self ._checkpoint_dir = Path (checkpoint_dir )
@@ -457,7 +463,7 @@ def __init__(
457463 output_dir = self ._output_dir ,
458464 adapter_checkpoint = adapter_checkpoint ,
459465 should_load_recipe_state = self ._should_load_recipe_state ,
460- pattern = r"^epoch_ (\d+)" ,
466+ pattern = r"^step_ (\d+)" ,
461467 )
462468
463469 # resume recipe_state ckpt
@@ -629,6 +635,8 @@ def save_checkpoint(
629635 epoch : int ,
630636 intermediate_checkpoint : bool = False ,
631637 adapter_only : bool = False ,
638+ * ,
639+ step : Optional [int ] = None ,
632640 ) -> None :
633641 """
634642 Save HF checkpoint to file. If ``intermediate_checkpoint`` is True, an additional
@@ -644,10 +652,19 @@ def save_checkpoint(
644652 intermediate_checkpoint (bool): If True, an additional checkpoint files for recipe state
645653 and (if applicable) adapter weights are created. Default is False
646654 adapter_only (bool): If True, only save the adapter weights. Default is False
655+ step (Optional[int]): Step number. If provided, the checkpoint is saved under ``step_{step}`` instead of ``epoch_{epoch}``. Default is None.
647656
648657 Raises:
649658 ValueError: if ``adapter_only`` is True and adapter checkpoint not found in state_dict.
650659 """
660+ # Prefer to use step, not epoch
661+ if step is not None :
662+ ckpt_save_dirname = f"step_{ step } "
663+ ckpt_pattern = r"^step_(\d+)"
664+ else :
665+ ckpt_save_dirname = f"epoch_{ epoch } "
666+ ckpt_pattern = r"^epoch_(\d+)"
667+
651668 # convert the state_dict back to hf format; do this inplace
652669 if not adapter_only :
653670 if self ._model_type == ModelType .PHI3_MINI :
@@ -747,7 +764,7 @@ def save_checkpoint(
747764 )
748765 map_original_name_to_new_name [cpt_idx ] = shard_name
749766 output_path = Path .joinpath (
750- self ._output_dir , f"epoch_ { epoch } " , shard_name
767+ self ._output_dir , ckpt_save_dirname , shard_name
751768 )
752769 output_path .parent .mkdir (parents = True , exist_ok = True )
753770 if not self ._safe_serialization :
@@ -779,7 +796,7 @@ def save_checkpoint(
779796 index_file_name = TORCH_INDEX_FNAME
780797
781798 index_path = Path .joinpath (
782- self ._output_dir , f"epoch_ { epoch } " , index_file_name
799+ self ._output_dir , ckpt_save_dirname , index_file_name
783800 )
784801
785802 index_data = {
@@ -796,7 +813,7 @@ def save_checkpoint(
796813 # convert_weights.peft_to_tune. The .pt format is not needed, but
797814 # it is an easy way to distinguish the adapters. Ideally we should save only one.
798815 output_path = Path .joinpath (
799- self ._output_dir , f"epoch_ { epoch } " , ADAPTER_MODEL_FNAME
816+ self ._output_dir , ckpt_save_dirname , ADAPTER_MODEL_FNAME
800817 ).with_suffix (".pt" )
801818 output_path .parent .mkdir (parents = True , exist_ok = True )
802819 torch .save (state_dict [training .ADAPTER_KEY ], output_path )
@@ -825,7 +842,7 @@ def save_checkpoint(
825842 head_dim = self ._config .get ("head_dim" , None ),
826843 )
827844 output_path = Path .joinpath (
828- self ._output_dir , f"epoch_ { epoch } " , ADAPTER_MODEL_FNAME
845+ self ._output_dir , ckpt_save_dirname , ADAPTER_MODEL_FNAME
829846 )
830847 output_path .parent .mkdir (parents = True , exist_ok = True )
831848 if not self ._safe_serialization :
@@ -866,7 +883,7 @@ def save_checkpoint(
866883 )
867884
868885 output_path = Path .joinpath (
869- self ._output_dir , f"epoch_ { epoch } " , ADAPTER_CONFIG_FNAME
886+ self ._output_dir , ckpt_save_dirname , ADAPTER_CONFIG_FNAME
870887 ).with_suffix (".json" )
871888 with open (output_path , "w" ) as f :
872889 json .dump (state_dict [training .ADAPTER_CONFIG ], f )
@@ -880,7 +897,7 @@ def save_checkpoint(
880897 # So it's easy to run inference with the model using this epoch's checkpoint
881898 copy_files (
882899 self ._checkpoint_dir ,
883- Path .joinpath (self ._output_dir , f"epoch_ { epoch } " ),
900+ Path .joinpath (self ._output_dir , ckpt_save_dirname ),
884901 ignore_suffixes = SUFFIXES_TO_NOT_COPY ,
885902 )
886903
@@ -901,7 +918,7 @@ def save_checkpoint(
901918 f"saved to { output_path } "
902919 )
903920 else :
904- logger .info ("Saving final epoch checkpoint." )
921+ logger .info ("Saving final checkpoint." )
905922 if adapter_only :
906923 logger .info (
907924 "Please note that you have set adapter_only=True, so only adapter weights will be saved."
@@ -914,6 +931,16 @@ def save_checkpoint(
914931 "You can now use this checkpoint for further training or inference."
915932 )
916933
934+ # If specified, prune the checkpoints in the output directory
935+ if self ._keep_last_n_checkpoints is not None :
936+ all_current_checkpoints = get_all_checkpoints_in_dir (
937+ self ._output_dir , pattern = ckpt_pattern
938+ )
939+ prune_surplus_checkpoints (
940+ all_current_checkpoints ,
941+ keep_last_n_checkpoints = self ._keep_last_n_checkpoints ,
942+ )
943+
917944
918945class FullModelMetaCheckpointer (_CheckpointerInterface ):
919946 """
0 commit comments