Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
is_accelerate_available,
is_hqq_available,
is_optimum_quanto_available,
is_torchdynamo_exporting,
logging,
)
from .candidate_generator import (
Expand Down Expand Up @@ -502,17 +501,20 @@ def _cache_dependant_input_preparation(
The current implementation does not rely on ``self`` and could be
a class method. It is left as a standard method to be easily rewritten.
"""
if is_torchdynamo_exporting():
return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position)
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
# Initial-embeddings case: first forward pass with only embeddings (input_ids is empty)
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif inputs_embeds is not None and cache_position[0] > 0:
# After first iteration, stop using embeddings and switch to input_ids
inputs_embeds = None
input_ids = input_ids[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
# Original logic for other cases
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

return inputs_embeds, input_ids

def _cache_dependant_input_preparation_exporting(
Expand Down
34 changes: 21 additions & 13 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2387,20 +2387,29 @@ def _inner_training_loop(
grad_norm: Optional[float] = None
learning_rate = None
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

if args.eval_on_start:
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

_steps_in_current_epoch = 0
epoch_dataloader = train_dataloader
if len_dataloader is not None:
steps_in_epoch = len_dataloader
else:
steps_in_epoch = args.max_steps * args.gradient_accumulation_steps
# Split the epoch iterator into chunks of `n` batches, where `n` is the number of gradient accumulation steps
remainder = steps_in_epoch % args.gradient_accumulation_steps
if remainder == 0:
remainder = args.gradient_accumulation_steps
for epoch in range(epochs_trained, num_train_epochs):
epoch_dataloader = train_dataloader
if len_dataloader is None and epoch > epochs_trained and _steps_in_current_epoch > 0:
steps_in_epoch = _steps_in_current_epoch
remainder = steps_in_epoch % args.gradient_accumulation_steps
if remainder == 0:
remainder = args.gradient_accumulation_steps
if len_dataloader is None:
_steps_in_current_epoch = 0
if hasattr(epoch_dataloader, "set_epoch"):
epoch_dataloader.set_epoch(epoch)

steps_in_epoch = (
len(epoch_dataloader)
if len_dataloader is not None
else args.max_steps * args.gradient_accumulation_steps
)
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

step = -1
Expand All @@ -2416,10 +2425,6 @@ def _inner_training_loop(
self._load_rng_state(resume_from_checkpoint)

epoch_iterator = iter(epoch_dataloader)
# Split the epoch iterator into chunks of `n` batches, where `n` is the number of gradient accumulation steps
remainder = steps_in_epoch % args.gradient_accumulation_steps
if remainder == 0:
remainder = args.gradient_accumulation_steps
update_step = -1
total_updates = steps_in_epoch // args.gradient_accumulation_steps + int(
remainder < args.gradient_accumulation_steps
Expand Down Expand Up @@ -2552,7 +2557,10 @@ def _inner_training_loop(

model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
if len_dataloader is None:
_steps_in_current_epoch += 1

self.state.epoch = epoch + (step + 1 + steps_trained_in_current_epoch) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(
tr_loss,
Expand Down