FIX(trainer): ensure final checkpoint is saved when resuming training #40347
Merged (+121 −24)

Commits (12; the diff below shows changes from 6 of them)
- fe7edb3  fix(trainer): ensure final checkpoint is saved when resuming training (rangehow)
- ea0ad02  add test (rangehow)
- 37ff988  make style && slight fix of test (rangehow)
- 38a216f  make style again (rangehow)
- 6dc2318  Merge branch 'main' into main (rangehow)
- fb04527  Merge branch 'huggingface:main' into main (rangehow)
- aa8a637  move test code to test_trainer
- 8712ce1  remove outdated test file
- 0967dc8  Apply style fixes (github-actions[bot])
- 04f93a5  Merge branch 'main' into main (SunMarc)
- 4805d48  Merge branch 'main' into main (SunMarc)
- 9826d0c  Merge branch 'main' into main (SunMarc)
New test file (@@ -0,0 +1,121 @@):

import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from transformers import (
    Trainer,
    TrainingArguments,
)
from transformers.testing_utils import TestCasePlus


class DummyModel(nn.Module):
    def __init__(self, input_dim=10, num_labels=2):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        logits = self.linear(input_ids.float())
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}


class DummyDictDataset(Dataset):
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx],
        }


def create_dummy_dataset():
    """Creates a dummy dataset for testing."""
    num_samples = 13
    input_dim = 10
    dummy_input_ids = torch.rand(num_samples, input_dim)
    dummy_attention_mask = torch.ones(num_samples, input_dim)
    dummy_labels = torch.randint(0, 2, (num_samples,))
    return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels)


class TestTrainerResume(TestCasePlus):
    def test_resume_with_original_trainer(self):
        """Tests that the Trainer saves the final checkpoint after resuming."""
        print("Testing the original transformers Trainer...")

        # 1. Set up a dummy model and dataset
        model = DummyModel(input_dim=10, num_labels=2)
        dummy_dataset = create_dummy_dataset()

        # 2. First training run (simulate an interruption)
        output_dir_initial = self.get_auto_remove_tmp_dir()
        training_args_initial = TrainingArguments(
            output_dir=output_dir_initial,
            num_train_epochs=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=3,
            save_strategy="steps",
            save_steps=1,  # Save at every step
            report_to=[],  # Disable wandb/tensorboard and other loggers
            max_steps=2,  # Stop after step 2 to simulate an interruption
        )

        trainer_initial = Trainer(
            model=model,
            args=training_args_initial,
            train_dataset=dummy_dataset,
        )
        trainer_initial.train()

        # Make sure we have a checkpoint before the interruption
        checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2")
        assert os.path.exists(checkpoint_path)

        print("Second phase")
        # 3. Resume training from the checkpoint
        output_dir_resumed = self.get_auto_remove_tmp_dir()
        training_args_resumed = TrainingArguments(
            output_dir=output_dir_resumed,
            num_train_epochs=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=3,
            save_strategy="steps",
            save_steps=1,  # Keep the same save strategy
        )

        trainer_resumed = Trainer(
            model=model,
            args=training_args_resumed,
            train_dataset=dummy_dataset,
        )
        # Resume from the interrupted checkpoint and finish the remaining training
        trainer_resumed.train(resume_from_checkpoint=checkpoint_path)

        # 4. Assertion: the final model checkpoint must have been saved
        final_model_path = os.path.join(output_dir_resumed, "checkpoint-3", "model.safetensors")
        assert os.path.exists(final_model_path), "Original Trainer: Final model checkpoint was not saved!"
        print("✓ Original Trainer: Final model has been saved.")


# Run all tests
if __name__ == "__main__":
    import unittest

    unittest.main()
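The magic numbers in this test are chosen so that exactly one optimizer step remains after the simulated interruption. A quick sanity check of that arithmetic (values copied from the test above; this snippet is only an illustration, not part of the PR):

```python
import math

num_samples = 13
per_device_train_batch_size = 2
gradient_accumulation_steps = 3

batches_per_epoch = math.ceil(num_samples / per_device_train_batch_size)                # 7
optimizer_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)  # 3

print(batches_per_epoch, optimizer_steps_per_epoch)  # 7 3
# The first run stops at max_steps=2, so "checkpoint-2" exists and one (partial
# accumulation) optimizer step remains. The resumed run must execute that last step
# and, with save_steps=1, write "checkpoint-3/model.safetensors".
```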
Conversations
Review comment: Does this really test the described issue? Here the training will finish without any error.
Reply: Based on my understanding, the test currently runs against the correct (fixed) Trainer, so it shouldn't trigger any error messages. Even with the incorrect Trainer, this test code won't report errors during training (which is why this bug is quite subtle); the error check is implemented by verifying whether checkpoint-3 is saved at the end.
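Put differently, the failure signature is silent: a buggy resumed run exits cleanly but never writes the final checkpoint. A small illustrative helper (not part of the PR) that makes the difference visible next to the assertion in the test:

```python
import os

def list_checkpoints(output_dir):
    """Return the checkpoint-* directories written under a Trainer output dir."""
    return sorted(d for d in os.listdir(output_dir) if d.startswith("checkpoint-"))

# With the fix, the resumed output dir contains "checkpoint-3" (holding model.safetensors).
# With the bug described in this PR, "checkpoint-3" is never written, so the
# os.path.exists assertion in the test above fails.
```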
Follow-up: The bug manifests as implicitly skipping training on the last gradient-accumulation batch, so the final checkpoint is never saved; it does not cause a crash during training. Please let me know if any part of this description is unclear 😊
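To make that mechanism concrete, here is a minimal sketch of how resume logic that counts whole accumulation groups can silently drop the trailing partial group. This is not the actual Trainer code, only an assumed simplification of the failure mode described above, using the numbers from the test:

```python
import math

NUM_BATCHES = 7   # 13 samples / batch size 2 -> 7 batches per epoch
ACCUM = 3         # gradient_accumulation_steps
RESUME_STEP = 2   # optimizer steps already completed before the interruption

def remaining_updates(num_batches, accum, resume_step, buggy):
    skipped_batches = resume_step * accum      # batches consumed before the interruption
    left = num_batches - skipped_batches       # batches still to be trained on
    if buggy:
        return left // accum                   # floor division drops the final partial group
    return math.ceil(left / accum)             # the partial group still counts as one update

print(remaining_updates(NUM_BATCHES, ACCUM, RESUME_STEP, buggy=True))   # 0 -> step 3 never runs, no checkpoint-3
print(remaining_updates(NUM_BATCHES, ACCUM, RESUME_STEP, buggy=False))  # 1 -> the final step runs, checkpoint-3 is saved
```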