@@ -5158,6 +5158,115 @@ def test_trainer_works_without_model_config(self):
         )
         trainer.train()

+    @require_safetensors
+    def test_resume_from_interrupted_training(self):
+        """
+        Tests resuming training from a checkpoint after a simulated interruption.
+        """
+
+        # --- Helper classes and functions defined locally for this test ---
+        class DummyModel(nn.Module):
+            def __init__(self, input_dim=10, num_labels=2):
+                super().__init__()
+                self.linear = nn.Linear(input_dim, num_labels)
+
+            def forward(self, input_ids=None, attention_mask=None, labels=None):
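+                # NOTE: input_ids here are dense float features fed straight to the
+                # linear classifier, not token indices, so no embedding layer is needed.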
+                logits = self.linear(input_ids.float())
+                loss = None
+                if labels is not None:
+                    loss_fn = nn.CrossEntropyLoss()
+                    loss = loss_fn(logits, labels)
+                return {"loss": loss, "logits": logits}
+
+        class DummyDictDataset(torch.utils.data.Dataset):
+            def __init__(self, input_ids, attention_mask, labels):
+                self.input_ids = input_ids
+                self.attention_mask = attention_mask
+                self.labels = labels
+
+            def __len__(self):
+                return len(self.input_ids)
+
+            def __getitem__(self, idx):
+                return {
+                    "input_ids": self.input_ids[idx],
+                    "attention_mask": self.attention_mask[idx],
+                    "labels": self.labels[idx],
+                }
+
+        def create_dummy_dataset():
+            """Creates a dummy dataset for this specific test."""
+            num_samples = 13
+            input_dim = 10
+            dummy_input_ids = torch.rand(num_samples, input_dim)
+            dummy_attention_mask = torch.ones(num_samples, input_dim)
+            dummy_labels = torch.randint(0, 2, (num_samples,))
+            return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels)
+
+        # 1. Set up a dummy model and dataset
+        model = DummyModel(input_dim=10, num_labels=2)
+        dummy_dataset = create_dummy_dataset()
+
+        # 2. First training phase (simulating an interruption)
+        output_dir_initial = self.get_auto_remove_tmp_dir()
+        training_args_initial = TrainingArguments(
+            output_dir=output_dir_initial,
+            num_train_epochs=1,
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=3,
+            save_strategy="steps",
+            save_steps=1,  # Save at every step
+            report_to=[],  # Disable wandb/tensorboard and other loggers
+            max_steps=2,  # Stop after step 2 to simulate interruption
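+            # max_steps takes precedence over num_train_epochs, so this run halts
+            # after 2 optimizer steps even though a full epoch would take 3.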
+        )
+
+        trainer_initial = Trainer(
+            model=model,
+            args=training_args_initial,
+            train_dataset=dummy_dataset,
+        )
+        trainer_initial.train()
+
+        # 3. Verify that a checkpoint was created before the "interruption"
+        checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2")
+        self.assertTrue(os.path.exists(checkpoint_path), f"Checkpoint not found at {checkpoint_path}")
+
+        # 4. Second training phase (resuming from the checkpoint)
+        output_dir_resumed = self.get_auto_remove_tmp_dir()
+        # Note: total steps for one epoch is ceil(13 / (2*3)) = 3.
+        # We stopped at step 2, so the resumed training should run for 1 more step.
+        training_args_resumed = TrainingArguments(
+            output_dir=output_dir_resumed,
+            num_train_epochs=1,
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=3,
+            save_strategy="steps",
+            save_steps=1,
+            report_to=[],
+        )
+
+        trainer_resumed = Trainer(
+            model=model,
+            args=training_args_resumed,
+            train_dataset=dummy_dataset,
+        )
+        # Resume from the interrupted checkpoint and finish the remaining training
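+        # Resuming reloads the saved weights into `model` and restores the optimizer,
+        # scheduler, RNG states, and trainer state (global_step) from the checkpoint directory.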
+        trainer_resumed.train(resume_from_checkpoint=checkpoint_path)
+
+        # 5. Assertions: Check if the training completed and the final model was saved
+        # The training should have completed step 3.
+        # Total steps per epoch = ceil(13 samples / (2 batch_size * 3 grad_accum)) = 3
+        self.assertEqual(trainer_resumed.state.global_step, 3)
+
+        # Check that a checkpoint for the final step exists.
+        final_checkpoint_path = os.path.join(output_dir_resumed, "checkpoint-3")
+        self.assertTrue(os.path.exists(final_checkpoint_path))
+
+        # Check if the model weights file exists in the final checkpoint directory.
+        # Trainer saves non-PreTrainedModel models as `model.safetensors` by default if safetensors is available.
+        final_model_path = os.path.join(final_checkpoint_path, SAFE_WEIGHTS_NAME)
+        self.assertTrue(os.path.exists(final_model_path), "Final model checkpoint was not saved!")
+

 @require_torch
 @is_staging_test