examples/seq2seq/seq2seq_trainer.py: 25 changes (13 additions, 12 deletions)
@@ -20,12 +20,6 @@
 from transformers.trainer_pt_utils import get_tpu_sampler


-try:
-    from .utils import label_smoothed_nll_loss
-except ImportError:
-    from utils import label_smoothed_nll_loss
-
-
 logger = logging.get_logger(__name__)

 arg_to_scheduler = {
@@ -64,6 +58,17 @@ def __init__(self, config=None, data_args=None, *args, **kwargs):
f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
)

if self.args.label_smoothing == 0:
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
else:
# dynamically import label_smoothed_nll_loss
try:
Contributor:
We moved all seq2seq scripts/tests so they no longer need the try/except workaround and just import from utils; see:
https://github.com/huggingface/transformers/pull/7274/files#diff-36ebee2221416f906bccc099a8958f6814f51a1dc8f2b07a75ffd19bd0d1ddf6
This script looks like a new file that was copied from some old-style code.

Of course, your PR didn't introduce it, but we may as well use the clean new style.
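For reference, the difference between the old workaround and the suggested style is only in how `label_smoothed_nll_loss` gets imported; a minimal sketch, assuming `examples/seq2seq` is on `sys.path` (the exact layout adopted in #7274 may differ):

```python
# Old workaround: package-relative import with a flat-import fallback.
# try:
#     from .utils import label_smoothed_nll_loss
# except ImportError:
#     from utils import label_smoothed_nll_loss

# Clean new style: a single flat import from the local utils module.
from utils import label_smoothed_nll_loss
```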

Contributor Author:
Thanks for letting me know! I will open a PR to correct it :-)

Contributor Author:
Pinged you on the PR :-) #8254

+                from .utils import label_smoothed_nll_loss
+            except ImportError:
+                from utils import label_smoothed_nll_loss
+
+            self.loss_fn = label_smoothed_nll_loss
+
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.
@@ -135,19 +140,15 @@ def _compute_loss(self, model, inputs, labels):
             if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
                 # force training to ignore pad token
                 logits = model(**inputs, use_cache=False)[0]
-
-                loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
-                loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
+                loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
             else:
                 # compute usual loss via models
                 loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
         else:
             # compute label smoothed loss
             logits = model(**inputs, use_cache=False)[0]
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
-            loss, _ = label_smoothed_nll_loss(
-                lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
-            )
+            loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id)
         return loss, logits

     def compute_loss(self, model, inputs):
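For context on what `self.loss_fn` points to when label smoothing is enabled: `label_smoothed_nll_loss` in `examples/seq2seq/utils.py` is a fairseq-derived helper that mixes the standard NLL term with a uniform-over-the-vocabulary term. A sketch of such a helper, written from the fairseq-style implementation rather than copied from the exact file, looks like this:

```python
import torch


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Fairseq-style label-smoothed NLL loss computed from log-probabilities."""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    # Standard NLL term: negative log-probability of the gold token.
    nll_loss = -lprobs.gather(dim=-1, index=target)
    # Smoothing term: total negative log-probability, i.e. mass spread uniformly over the vocab.
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        # Zero out positions holding the padding label before reducing.
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 5, 10)           # (batch, seq_len, vocab_size)
    labels = torch.randint(0, 10, (2, 5))    # gold token ids; 0 treated as padding here
    lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
    loss, nll = label_smoothed_nll_loss(lprobs, labels, epsilon=0.1, ignore_index=0)
    print(loss.item(), nll.item())
```

This also explains the two call shapes in `_compute_loss` above: `torch.nn.CrossEntropyLoss` has `ignore_index` baked in at construction time and is called as `self.loss_fn(flat_logits, flat_labels)`, while the smoothed helper takes log-probabilities plus the smoothing value and `ignore_index` at call time and returns a `(loss, nll_loss)` pair, hence the `loss, _ = self.loss_fn(...)` unpacking.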