Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/seq2seq/seq2seq_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def prediction_step(
}

if self.args.predict_with_generate and not self.args.prediction_loss_only:
generated_tokens = model.generate(
_model = model.module if hasattr(model, "module") else model
generated_tokens = _model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**gen_kwargs,
Expand Down
9 changes: 1 addition & 8 deletions examples/seq2seq/test_finetune_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@

from transformers import BertTokenizer, EncoderDecoderModel
from transformers.file_utils import is_datasets_available
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
get_gpu_count,
require_torch_non_multi_gpu_but_fix_me,
slow,
)
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, get_gpu_count, slow
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed

Expand Down Expand Up @@ -52,7 +46,6 @@ def test_finetune_trainer_slow(self):
assert "test_results.json" in contents

@slow
@require_torch_non_multi_gpu_but_fix_me
def test_finetune_bert2bert(self):
if not is_datasets_available():
return
Expand Down