36 changes: 12 additions & 24 deletions src/transformers/trainer.py
@@ -2535,7 +2535,6 @@ def _inner_training_loop(
start_time = time.time()
epochs_trained = 0
steps_trained_in_current_epoch = 0
steps_trained_progress_bar = None

# Check if continuing training from a checkpoint
if resume_from_checkpoint is not None and os.path.isfile(
@@ -2596,18 +2595,18 @@ def _inner_training_loop(
)
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)

step = -1
rng_to_sync = False
steps_skipped = 0
if steps_trained_in_current_epoch > 0:
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
steps_skipped = steps_trained_in_current_epoch
steps_trained_in_current_epoch = 0
rng_to_sync = True

step = -1
# Handle resumption from checkpoint
if epoch == epochs_trained and resume_from_checkpoint is not None:
if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip:
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
step = steps_trained_in_current_epoch - 1
rng_to_sync = True
elif steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)

epoch_iterator = iter(epoch_dataloader)
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
remainder = steps_in_epoch % args.gradient_accumulation_steps
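Why continuing the step counter matters: if `step` restarts at -1 after `skip_first_batches`, the gradient-accumulation modulo and end-of-epoch checks no longer line up with the true position in the epoch, so a trailing partial accumulation group can be dropped. A toy sketch with assumed names and a simplified update rule (not the real Trainer loop):

# Toy illustration only — simplified update rule and assumed values, not the actual Trainer code.
gradient_accumulation_steps = 3
steps_in_epoch = 7                      # e.g. 13 samples at batch size 2
steps_trained_in_current_epoch = 6      # two optimizer steps already done before the interruption
remaining_batches = steps_in_epoch - steps_trained_in_current_epoch  # skip_first_batches leaves 1

def updates_after_resume(start_step):
    """Count optimizer updates while consuming the remaining batches,
    given the value `step` is initialized to before the loop."""
    updates, step = 0, start_step
    for _ in range(remaining_batches):
        step += 1
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch:
            updates += 1
    return updates

print(updates_after_resume(-1))                                  # old counter: 0 — the final update is lost
print(updates_after_resume(steps_trained_in_current_epoch - 1))  # continued counter: 1 — global step 3 happens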
@@ -2642,22 +2641,11 @@ def _inner_training_loop(
input_tokens = inputs[main_input_name].numel()
input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()

if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False

# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
steps_trained_progress_bar = None

if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

Expand Down Expand Up @@ -2749,7 +2737,7 @@ def _inner_training_loop(

model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(
tr_loss,
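A quick check of the epoch-fraction change above, with the same toy numbers (assumed values, not taken from a real run): with the old scheme `step` restarted at 0 after skipping, so the skipped count had to be added back; with the new scheme `step` already reflects the true position in the epoch, so the fraction is correct as-is.

steps_in_epoch, steps_skipped = 7, 6
old_step, new_step = 0, 6   # first remaining batch after resuming, under each scheme
assert (old_step + 1 + steps_skipped) / steps_in_epoch == (new_step + 1) / steps_in_epoch == 1.0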
121 changes: 121 additions & 0 deletions tests/trainer/test_trainer_resume.py
@@ -0,0 +1,121 @@
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from transformers import (
Trainer,
TrainingArguments,
)
from transformers.testing_utils import TestCasePlus


class DummyModel(nn.Module):
def __init__(self, input_dim=10, num_labels=2):
super().__init__()
self.linear = nn.Linear(input_dim, num_labels)

def forward(self, input_ids=None, attention_mask=None, labels=None):
logits = self.linear(input_ids.float())
loss = None
if labels is not None:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)
return {"loss": loss, "logits": logits}


class DummyDictDataset(Dataset):
def __init__(self, input_ids, attention_mask, labels):
self.input_ids = input_ids
self.attention_mask = attention_mask
self.labels = labels

def __len__(self):
return len(self.input_ids)

def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.labels[idx],
}


def create_dummy_dataset():
"""Creates a dummy dataset for testing."""
num_samples = 13
input_dim = 10
dummy_input_ids = torch.rand(num_samples, input_dim)
dummy_attention_mask = torch.ones(num_samples, input_dim)
dummy_labels = torch.randint(0, 2, (num_samples,))
return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels)


class TestTrainerResume(TestCasePlus):
def test_resume_with_original_trainer(self):
"""Tests the original transformers Trainer."""
print("Testing the original transformers Trainer...")

# 1. Set up a dummy model and dataset
model = DummyModel(input_dim=10, num_labels=2)
dummy_dataset = create_dummy_dataset()

# 2. First training run (simulate an interruption)
output_dir_initial = self.get_auto_remove_tmp_dir()
training_args_initial = TrainingArguments(
output_dir=output_dir_initial,
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=3,
save_strategy="steps",
save_steps=1, # Save at every step
report_to=[], # Disable wandb/tensorboard and other loggers
max_steps=2, # Stop after step 2 to simulate interruption
Member:
Does it really test the described issue? Here the training will finish without any error.

Contributor Author:
Based on my understanding, the test currently exercises the fixed Trainer, so it shouldn't trigger any error messages. Even with the broken Trainer, this test code won't report errors during training (which is why this bug is quite subtle). The error check is implemented by verifying whether checkpoint-3 is saved at the end.

Contributor Author:
The bug manifests as silently skipping training on the last gradient-accumulation batch of data, so the final checkpoint is never saved, rather than as a crash during training. Please let me know if any part of my description is unclear 😊
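For reference, the arithmetic behind the checkpoint-3 check, assuming a single process and the default dataloader_drop_last=False:

import math

num_samples, batch_size, grad_accum = 13, 2, 3                                # values used in this test
micro_batches_per_epoch = math.ceil(num_samples / batch_size)                 # 7 (last batch holds 1 sample)
optimizer_steps_per_epoch = math.ceil(micro_batches_per_epoch / grad_accum)   # 3 (two full groups + one partial)
already_trained = 2 * grad_accum                                              # max_steps=2 stops after 6 micro-batches
remaining = micro_batches_per_epoch - already_trained                         # 1 micro-batch left after resuming
print(micro_batches_per_epoch, optimizer_steps_per_epoch, remaining)          # 7 3 1

That single leftover micro-batch must still trigger the third optimizer step, which is what the final assertion verifies.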

)

trainer_initial = Trainer(
model=model,
args=training_args_initial,
train_dataset=dummy_dataset,
)
trainer_initial.train()

# Make sure we have a checkpoint before interruption
checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2")
assert os.path.exists(checkpoint_path)

print("Second phase")
# 4. Resume training from checkpoint
output_dir_resumed = self.get_auto_remove_tmp_dir()
training_args_resumed = TrainingArguments(
output_dir=output_dir_resumed,
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=3,
save_strategy="steps",
save_steps=1, # Keep the same save strategy
)

trainer_resumed = Trainer(
model=model,
args=training_args_resumed,
train_dataset=dummy_dataset,
)
# Resume from the interrupted checkpoint and finish the remaining training
trainer_resumed.train(resume_from_checkpoint=checkpoint_path)

# 4. Assertion: check that the final model checkpoint has been saved
final_model_path = os.path.join(output_dir_resumed, "checkpoint-3", "model.safetensors")
assert os.path.exists(final_model_path), "Final model checkpoint was not saved!"
print("✓ Final model checkpoint has been saved.")


# Run all tests
if __name__ == "__main__":
import unittest

unittest.main()