36 changes: 14 additions & 22 deletions src/transformers/trainer.py
@@ -2584,18 +2584,19 @@ def _inner_training_loop(
)
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)

step = -1
rng_to_sync = False
steps_skipped = 0
if steps_trained_in_current_epoch > 0:
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
steps_skipped = steps_trained_in_current_epoch
steps_trained_in_current_epoch = 0
rng_to_sync = True

step = -1
# Handle resumption from checkpoint
if epoch == epochs_trained and resume_from_checkpoint is not None:
if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip:
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
step = steps_trained_in_current_epoch - 1
rng_to_sync = True
elif steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)


epoch_iterator = iter(epoch_dataloader)
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
remainder = steps_in_epoch % args.gradient_accumulation_steps
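For intuition, here is a self-contained sketch of what the new resumption branch does with a plain DataLoader. The import path for `skip_first_batches` and the value of `steps_trained_in_current_epoch` are assumptions for illustration, not part of the PR:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import skip_first_batches  # the same helper the diff calls; import path assumed

dataloader = DataLoader(TensorDataset(torch.arange(10.0).unsqueeze(1)), batch_size=2)  # 5 batches
steps_trained_in_current_epoch = 3  # assumed: recovered from the checkpoint's trainer state

# Drop the batches that were already consumed before the interruption...
epoch_dataloader = skip_first_batches(dataloader, steps_trained_in_current_epoch)
# ...and pre-set `step` so the loop counter continues where it left off.
step = steps_trained_in_current_epoch - 1

for batch in epoch_dataloader:
    step += 1  # a real training step would run here

print(step)  # 4, i.e. the last of the 5 batches, numbered as if training was never interrupted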
@@ -2630,21 +2631,12 @@ def _inner_training_loop(
input_tokens = inputs[main_input_name].numel()
input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()

if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False

# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
steps_trained_progress_bar = None


if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
@@ -2737,7 +2729,7 @@ def _inner_training_loop(

model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(
tr_loss,
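A quick check, with assumed numbers, of why the epoch-fraction update above no longer needs `steps_skipped`: since `step` is now pre-set to `steps_trained_in_current_epoch - 1` before the loop, the counter already accounts for the skipped batches.

steps_in_epoch = 8                     # assumed for illustration
steps_trained_in_current_epoch = 3     # batches already trained before the interruption
epoch = 0

# Old bookkeeping: `step` restarted at 0 after the skip, so the skipped
# batches had to be added back through `steps_skipped`.
step_old, steps_skipped = 0, steps_trained_in_current_epoch
fraction_old = epoch + (step_old + 1 + steps_skipped) / steps_in_epoch

# New bookkeeping: `step` is pre-set, so the first post-resume batch already
# carries the right index and `steps_skipped` can be dropped.
step_new = (steps_trained_in_current_epoch - 1) + 1
fraction_new = epoch + (step_new + 1) / steps_in_epoch

assert fraction_old == fraction_new == 0.5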
120 changes: 120 additions & 0 deletions tests/trainer/test_trainer_resume.py
@@ -0,0 +1,120 @@
import os
import shutil

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from transformers import Trainer, TrainingArguments


class DummyModel(nn.Module):
def __init__(self, input_dim=10, num_labels=2):
super().__init__()
self.linear = nn.Linear(input_dim, num_labels)

def forward(self, input_ids=None, attention_mask=None, labels=None):
logits = self.linear(input_ids.float())
loss = None
if labels is not None:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)
return {"loss": loss, "logits": logits}

class DummyDictDataset(Dataset):
def __init__(self, input_ids, attention_mask, labels):
self.input_ids = input_ids
self.attention_mask = attention_mask
self.labels = labels

def __len__(self):
return len(self.input_ids)

def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.labels[idx],
}

def create_dummy_dataset():
"""Creates a dummy dataset for testing."""
num_samples = 13
input_dim = 10
dummy_input_ids = torch.rand(num_samples, input_dim)
dummy_attention_mask = torch.ones(num_samples, input_dim)
dummy_labels = torch.randint(0, 2, (num_samples,))
return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels)

def test_resume_with_original_trainer():
Member: let's put that in test_trainer.py under the TrainerIntegrationTest class

Contributor Author: done

Member: maybe you forgot to push the commit?

Contributor Author: The code you're looking at seems to be outdated. I've already pushed a new version of the test code, so you might be commenting on the previous one.
In my new code, the test class now inherits from TestCasePlus, because this parent class already includes the logic for setting up and cleaning up test resources automatically.
I'm not entirely sure what else is needed. From my perspective, the code should be complete now.

Contributor Author: see commit 37ff988

Member: I mean to put your test code here:

class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):

It will be better not to create a new file when introducing new tests.

Contributor Author: oh, I see. Thanks for the explanation. :hug

"""Tests the original transformers Trainer."""
print("Testing the original transformers Trainer...")

# 1. Set up a dummy model and dataset
model = DummyModel(input_dim=10, num_labels=2)
dummy_dataset = create_dummy_dataset()

# 2. First training run (simulate an interruption)
output_dir_initial = "./test_original_trainer_initial"
training_args_initial = TrainingArguments(
output_dir=output_dir_initial,
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=3,
save_strategy="steps",
save_steps=1, # Save at every step
report_to=[], # Disable wandb/tensorboard and other loggers
max_steps=2, # Stop after step 2 to simulate interruption
)
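# Step arithmetic behind max_steps=2 and the checkpoint names used below
# (informal note, assuming a single device and default dataloader settings):
# 13 samples at a per-device batch size of 2 give 7 batches per epoch, and with
# gradient_accumulation_steps=3 that is ceil(7 / 3) = 3 optimizer steps in total.
# Stopping at max_steps=2 therefore leaves checkpoint-2 behind, and the resumed
# run has only step 3 left, which is why checkpoint-3 is asserted at the end.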

trainer_initial = Trainer(
model=model,
args=training_args_initial,
train_dataset=dummy_dataset,
)
trainer_initial.train()

# Make sure the interrupted run left a checkpoint behind
checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2")
assert os.path.exists(checkpoint_path)

print("Second phase")
# 4. Resume training from checkpoint
output_dir_resumed = "./test_original_trainer_resumed"
training_args_resumed = TrainingArguments(
output_dir=output_dir_resumed,
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=3,
save_strategy="steps",
save_steps=1, # Keep the same save strategy
report_to=[], # Keep loggers disabled for the resumed run as well
)

trainer_resumed = Trainer(
model=model,
args=training_args_resumed,
train_dataset=dummy_dataset,
)
# Resume from the interrupted checkpoint and finish the remaining training
trainer_resumed.train(resume_from_checkpoint=checkpoint_path)

# 4. Assertion: check that the final model has been saved
final_model_path = os.path.join(output_dir_resumed, "checkpoint-3", "model.safetensors")
assert os.path.exists(final_model_path), "Original Trainer: final model checkpoint was not saved!"
print("✓ Original Trainer: final model has been saved.")


# Clean up test directories
shutil.rmtree(output_dir_initial)
shutil.rmtree(output_dir_resumed)


# Run the test when executed as a script
if __name__ == "__main__":
test_resume_with_original_trainer()
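Following the reviewer's suggestion, a rough sketch of how this test could live inside TrainerIntegrationTest in tests/trainer/test_trainer.py instead of a new file, leaning on TestCasePlus.get_auto_remove_tmp_dir() for cleanup. The method name and the reuse of DummyModel / create_dummy_dataset are assumptions, not the final code; the class bases are the ones named in the review.

class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
    def test_resume_after_mid_epoch_interruption(self):  # name is an assumption
        model = DummyModel(input_dim=10, num_labels=2)
        dummy_dataset = create_dummy_dataset()

        # TestCasePlus removes these directories automatically after the test.
        initial_dir = self.get_auto_remove_tmp_dir()
        resumed_dir = self.get_auto_remove_tmp_dir()

        initial_args = TrainingArguments(
            output_dir=initial_dir,
            num_train_epochs=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=3,
            save_strategy="steps",
            save_steps=1,
            report_to=[],
            max_steps=2,  # simulate an interruption after step 2
        )
        Trainer(model=model, args=initial_args, train_dataset=dummy_dataset).train()

        checkpoint_path = os.path.join(initial_dir, "checkpoint-2")
        self.assertTrue(os.path.exists(checkpoint_path))

        resumed_args = TrainingArguments(
            output_dir=resumed_dir,
            num_train_epochs=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=3,
            save_strategy="steps",
            save_steps=1,
            report_to=[],
        )
        # Resume from the interrupted checkpoint and finish the remaining step.
        Trainer(model=model, args=resumed_args, train_dataset=dummy_dataset).train(
            resume_from_checkpoint=checkpoint_path
        )
        self.assertTrue(
            os.path.exists(os.path.join(resumed_dir, "checkpoint-3", "model.safetensors"))
        )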