65 changes: 48 additions & 17 deletions src/transformers/trainer.py
@@ -2426,15 +2426,28 @@ def _inner_training_loop(
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
self.compare_trainer_and_checkpoint_args(self.args, self.state)
self._load_callback_state()
epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
if num_update_steps_per_epoch is not None:
epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
else:
steps_trained_in_current_epoch = 0
else:
steps_trained_in_current_epoch = 0
# If the dataloader does not have a length, we cannot restore the number of trained epochs.
# In the following loop, we repeatedly iterate over the dataloader to skip the first
# `steps_trained_in_current_epoch` steps and increment `epochs_trained` accordingly.
epochs_trained = 0
steps_trained_in_current_epoch = self.state.global_step * args.gradient_accumulation_steps
if args.ignore_data_skip:
raise ValueError(
"The dataloader does not have a length, so it is impossible to restore the number of trained"
" epochs. Please disable the `ignore_data_skip` option."
)

logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(f" Continuing training from epoch {epochs_trained}")
if num_update_steps_per_epoch is not None:
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {self.state.global_step}")
if not args.ignore_data_skip:
logger.info(
@@ -2467,6 +2480,32 @@ def _inner_training_loop(
if hasattr(epoch_dataloader, "set_epoch"):
epoch_dataloader.set_epoch(epoch)

steps_skipped = 0
rng_to_sync = False
epoch_iterator = None
if steps_trained_in_current_epoch > 0 and num_update_steps_per_epoch is None:
# Since the dataloader does not have a length, we just loop until the required number of steps has been skipped.
# Every time we reach the end of the dataloader, we increment epoch and reset the iterator.
epoch_iterator = iter(epoch_dataloader)
epoch_over = False
while steps_trained_in_current_epoch > 0:
try:
# If the dataloader yields N batches and N is not divisible by `args.gradient_accumulation_steps`,
# the update loop ignores the last `N % args.gradient_accumulation_steps` batches of an epoch.
# To replicate the same behavior when resuming training, we ignore such batches from skipped epochs.
for _ in range(args.gradient_accumulation_steps):
next(epoch_iterator)
steps_trained_in_current_epoch -= args.gradient_accumulation_steps
steps_skipped += args.gradient_accumulation_steps
except StopIteration:
epoch_over = True
break
if epoch_over:
epochs_trained += 1
continue
assert steps_trained_in_current_epoch == 0
rng_to_sync = True

# Reset the past mems state at the beginning of each epoch if necessary.
if args.past_index >= 0:
self._past = None
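
The skip-ahead loop added in this hunk can be hard to follow in diff form. The following standalone sketch (names such as skip_batches and make_iterator are mine, not part of the Trainer) mirrors its behaviour: batches are consumed in groups of gradient_accumulation_steps, and a trailing group that cannot be completed is discarded, exactly as it would have been during the original epoch.

def skip_batches(make_iterator, batches_to_skip, gradient_accumulation_steps):
    """Skip already-consumed batches of a length-less dataloader (illustrative only)."""
    epochs_skipped = 0
    iterator = make_iterator()
    while batches_to_skip > 0:
        try:
            # Pull a full accumulation group; only a complete group counts as skipped.
            for _ in range(gradient_accumulation_steps):
                next(iterator)
            batches_to_skip -= gradient_accumulation_steps
        except StopIteration:
            # The trailing, incomplete group is dropped and a fresh epoch begins.
            epochs_skipped += 1
            iterator = make_iterator()
    return iterator, epochs_skipped

# Example: a 9-batch stream with gradient_accumulation_steps=2 gives 4 updates per epoch
# (8 batches used, 1 dropped). Resuming at global step 10 means skipping 20 batches:
# two full epochs plus 4 batches of the third one.
iterator, epochs_skipped = skip_batches(lambda: iter(range(9)), batches_to_skip=20, gradient_accumulation_steps=2)
assert epochs_skipped == 2 and next(iterator) == 4
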
@@ -2481,16 +2520,15 @@ def _inner_training_loop(
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)

rng_to_sync = False
steps_skipped = 0
if steps_trained_in_current_epoch > 0:
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
steps_skipped = steps_trained_in_current_epoch
steps_trained_in_current_epoch = 0
rng_to_sync = True

step = -1
epoch_iterator = iter(epoch_dataloader)
if epoch_iterator is None:
epoch_iterator = iter(epoch_dataloader)
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
remainder = steps_in_epoch % args.gradient_accumulation_steps
if remainder == 0:
@@ -2648,13 +2686,6 @@ def _inner_training_loop(
if is_torch_xla_available():
xm.mark_step()
break
if step < 0:
logger.warning(
"There seems not to be a single sample in your epoch_iterator, stopping training at step"
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
f" num_steps ({max_steps}) higher than the number of available samples."
)
self.control.should_training_stop = True

self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(
@@ -5385,7 +5416,7 @@ def set_initial_training_values(
elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
num_train_epochs = sys.maxsize
num_update_steps_per_epoch = max_steps
num_update_steps_per_epoch = None
num_examples = total_train_batch_size * args.max_steps
num_train_samples = args.max_steps * total_train_batch_size
else:
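Putting the two trainer.py hunks together: once set_initial_training_values returns None for num_update_steps_per_epoch, the resume bookkeeping at the top of _inner_training_loop has to branch on it. A condensed sketch of that logic (function and argument names here are illustrative, not the actual Trainer API):

def resume_bookkeeping(global_step, num_update_steps_per_epoch, gradient_accumulation_steps, ignore_data_skip):
    if num_update_steps_per_epoch is not None:
        # Sized dataloader: both the finished epochs and the offset inside the current epoch are recoverable.
        epochs_trained = global_step // num_update_steps_per_epoch
        steps_to_skip = 0
        if not ignore_data_skip:
            steps_to_skip = (global_step % num_update_steps_per_epoch) * gradient_accumulation_steps
    else:
        # Length-less dataloader: epochs cannot be recovered, so the whole training history is
        # expressed as batches to skip from the start of the stream.
        if ignore_data_skip:
            raise ValueError("Cannot restore the trained epochs without skipping data.")
        epochs_trained = 0
        steps_to_skip = global_step * gradient_accumulation_steps
    return epochs_trained, steps_to_skip
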
38 changes: 38 additions & 0 deletions tests/trainer/test_trainer.py
@@ -3421,6 +3421,44 @@ def test_resume_training_with_frozen_params(self):
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)

@parameterized.expand([(9, 1), (10, 1), (11, 1), (20, 1), (21, 1), (9, 3), (9, 2)])
def test_resume_training_with_iterable_dataset(self, dataset_length, gradient_accumulation_steps):
with tempfile.TemporaryDirectory() as tmpdir:

def get_trainer():
config = RegressionModelConfig()
train_dataset = SampleIterableDataset(length=dataset_length)
model = RegressionRandomPreTrainedModel(config)
args = RegressionTrainingArguments(
output_dir=tmpdir,
learning_rate=0.1,
max_steps=20,
save_steps=10,
per_device_train_batch_size=1,
gradient_accumulation_steps=gradient_accumulation_steps,
)
return Trainer(model=model, args=args, train_dataset=train_dataset)

# Train from scratch.
trainer = get_trainer()
trainer.train()
self.assertEqual(trainer.state.global_step, 20)
(a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state)

# Train from a checkpoint.
checkpoint = os.path.join(tmpdir, "checkpoint-10")
trainer = get_trainer()
trainer.train(resume_from_checkpoint=checkpoint)
self.assertEqual(trainer.state.global_step, 20)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)

# Check that the resumed model is the same as the original one.
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)

def test_load_best_model_at_end(self):
total = int(self.n_epochs * 64 / self.batch_size)
with tempfile.TemporaryDirectory() as tmpdir:
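
As a rough sanity check on the parameterized cases (my own arithmetic, not part of the test file), the resume position implied by checkpoint-10 for each (dataset_length, gradient_accumulation_steps) pair can be worked out as follows:

for dataset_length, grad_accum in [(9, 1), (10, 1), (11, 1), (20, 1), (21, 1), (9, 3), (9, 2)]:
    updates_per_epoch = dataset_length // grad_accum  # trailing remainder batches are dropped
    epochs_done, updates_into_epoch = divmod(10, updates_per_epoch)
    print(f"length={dataset_length:2d} accum={grad_accum}: resume in epoch {epochs_done}, "
          f"{updates_into_epoch * grad_accum} batches already consumed in that epoch")

The (9, 2) case lands mid-epoch with one remainder batch dropped per epoch, which is exactly the situation the new skip-ahead loop has to reproduce.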