FIX(trainer): ensure final checkpoint is saved when resuming training #40347
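The scenario this PR targets can be sketched with a toy loop: a run is interrupted partway through, training resumes from the last saved checkpoint, and the resumed run must still write a checkpoint for the final step. This is pure-Python illustration with invented names (`toy_train`, `stop_after`), not the Trainer's actual internals:

```python
import os
import tempfile


def toy_train(total_steps, start_step, stop_after, ckpt_dir):
    """Simulate training that saves a checkpoint after every optimizer step."""
    last_saved = start_step
    for step in range(start_step + 1, total_steps + 1):
        if stop_after is not None and step > stop_after:
            break  # simulate an interruption (Ctrl-C, preemption, ...)
        with open(os.path.join(ckpt_dir, f"checkpoint-{step}"), "w") as f:
            f.write("weights")
        last_saved = step
    return last_saved


with tempfile.TemporaryDirectory() as ckpt_dir:
    # First run: interrupted after step 2 of 3.
    last = toy_train(total_steps=3, start_step=0, stop_after=2, ckpt_dir=ckpt_dir)
    assert last == 2
    # Resumed run: picks up at step 3; the final checkpoint must be written.
    toy_train(total_steps=3, start_step=last, stop_after=None, ckpt_dir=ckpt_dir)
    assert os.path.exists(os.path.join(ckpt_dir, "checkpoint-3"))
```

The test added by the PR checks exactly this end condition against the real `Trainer`: after resuming from `checkpoint-2`, the final `checkpoint-3` must exist.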
Status: Merged

Viewing changes from 2 commits (fe7edb3). All 12 commits:
- ea0ad02 fix(trainer): ensure final checkpoint is saved when resuming training (rangehow)
- 37ff988 add test (rangehow)
- 38a216f make style && slight fix of test (rangehow)
- 6dc2318 make style again (rangehow)
- fb04527 Merge branch 'main' into main (rangehow)
- aa8a637 Merge branch 'huggingface:main' into main (rangehow)
- 8712ce1 move test code to test_trainer
- 0967dc8 remove outdated test file
- 04f93a5 Apply style fixes (github-actions[bot])
- 4805d48 Merge branch 'main' into main (SunMarc)
- 9826d0c Merge branch 'main' into main (SunMarc)
- Merge branch 'main' into main (SunMarc)
The diff adds a new standalone test file (120 lines; later moved into tests/trainer/test_trainer.py and the standalone file removed, per the review below):

```python
import os
import shutil

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from transformers import Trainer, TrainingArguments


class DummyModel(nn.Module):
    def __init__(self, input_dim=10, num_labels=2):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        logits = self.linear(input_ids.float())
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}


class DummyDictDataset(Dataset):
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx],
        }


def create_dummy_dataset():
    """Creates a dummy dataset for testing."""
    num_samples = 13
    input_dim = 10
    dummy_input_ids = torch.rand(num_samples, input_dim)
    dummy_attention_mask = torch.ones(num_samples, input_dim)
    dummy_labels = torch.randint(0, 2, (num_samples,))
    return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels)


def test_resume_with_original_trainer():
    """Tests that the original transformers Trainer saves a final checkpoint after resuming."""
    print("Testing the original transformers Trainer...")

    # 1. Set up a dummy model and dataset
    model = DummyModel(input_dim=10, num_labels=2)
    dummy_dataset = create_dummy_dataset()

    # 2. First training run (simulate an interruption via max_steps)
    output_dir_initial = "./test_original_trainer_initial"
    training_args_initial = TrainingArguments(
        output_dir=output_dir_initial,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=3,
        save_strategy="steps",
        save_steps=1,  # Save at every step
        report_to=[],  # Disable wandb/tensorboard and other loggers
        max_steps=2,  # Stop after step 2 to simulate an interruption
    )

    trainer_initial = Trainer(
        model=model,
        args=training_args_initial,
        train_dataset=dummy_dataset,
    )
    trainer_initial.train()

    # Make sure we have a checkpoint from before the "interruption"
    checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2")
    assert os.path.exists(checkpoint_path)

    print("Second phase")
    # 3. Resume training from the checkpoint
    output_dir_resumed = "./test_original_trainer_resumed"
    training_args_resumed = TrainingArguments(
        output_dir=output_dir_resumed,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=3,
        save_strategy="steps",
        save_steps=1,  # Keep the same save strategy
        report_to=[],
    )

    trainer_resumed = Trainer(
        model=model,
        args=training_args_resumed,
        train_dataset=dummy_dataset,
    )
    # Resume from the interrupted checkpoint and finish the remaining training
    trainer_resumed.train(resume_from_checkpoint=checkpoint_path)

    # 4. Assertion: check that the final model checkpoint has been saved
    final_model_path = os.path.join(output_dir_resumed, "checkpoint-3", "model.safetensors")
    try:
        assert os.path.exists(final_model_path), "Original Trainer: Final model checkpoint was not saved!"
        print("✓ Original Trainer: Final model has been saved.")
    except AssertionError as e:
        print(f"✗ Original Trainer: {e}")

    # Clean up test directories
    shutil.rmtree(output_dir_initial)
    shutil.rmtree(output_dir_resumed)


# Run all tests
if __name__ == "__main__":
    test_resume_with_original_trainer()
```
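The checkpoint numbers the test expects follow from the optimizer-step arithmetic: 13 samples at a per-device batch size of 2 give ceil(13/2) = 7 batches per epoch, and with gradient_accumulation_steps=3 that is ceil(7/3) = 3 optimizer steps in one epoch (a trailing partial accumulation still counts as a step). The interrupted run stops at step 2 (checkpoint-2), so the resumed run has exactly one step left and should end at checkpoint-3. A quick check:

```python
import math

num_samples = 13
batch_size = 2
grad_accum = 3

batches_per_epoch = math.ceil(num_samples / batch_size)      # 7 batches (last one is partial)
steps_per_epoch = math.ceil(batches_per_epoch / grad_accum)  # 3 optimizer steps

print(batches_per_epoch, steps_per_epoch)  # → 7 3
```

This also explains `max_steps=2` in the first run: it interrupts training one optimizer step short of the full epoch.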
Review conversation:

SunMarc: Let's put that in test_trainer.py under the TrainerIntegrationTest class.
rangehow: Done.
SunMarc: Maybe you forgot to push the commit?
rangehow: The code you're looking at seems to be outdated; I've already pushed a new version of the test code, so you might be commenting on the previous one. In my new code, the test class inherits from TestCasePlus, because that parent class already includes the logic for finding and cleaning up resources automatically. I'm not entirely sure what else is needed; from my perspective, the code should be complete now.
rangehow: See commit 37ff988.
SunMarc: I meant to put your test code here: transformers/tests/trainer/test_trainer.py, line 1304 (at 263d06f). It is better not to create a new file when introducing new tests.
rangehow: Oh, I see. Thanks for the explanation. :hug
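The reviewer's suggestion, folding a standalone test script into an existing test class instead of adding a new file, follows the usual unittest pattern. Below is a minimal sketch of that shape using plain `unittest` and `tempfile`; the class and method names are illustrative only. The actual PR adds the test to transformers' `TrainerIntegrationTest` in tests/trainer/test_trainer.py, whose `TestCasePlus` base handles temporary-directory creation and cleanup automatically.

```python
import os
import tempfile
import unittest


class TrainerResumeTest(unittest.TestCase):
    """Hypothetical test-class shape; in the PR this is a method on TrainerIntegrationTest."""

    def test_final_checkpoint_saved_after_resume(self):
        # A TestCasePlus-style helper would hand out an auto-removed tmp dir;
        # plain unittest can use tempfile.TemporaryDirectory instead.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # ... run the interrupted training, then resume from the last checkpoint ...
            final_ckpt = os.path.join(tmp_dir, "checkpoint-3")
            os.makedirs(final_ckpt)  # stand-in for the Trainer's final save
            self.assertTrue(os.path.exists(final_ckpt))
```

Run it with `python -m unittest <module>`; as a method on an existing class, it is picked up by the test suite's discovery with no extra files.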