diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py
index 1b3d52ad4487..8ad1d7a6eeab 100644
--- a/examples/seq2seq/seq2seq_trainer.py
+++ b/examples/seq2seq/seq2seq_trainer.py
@@ -20,12 +20,6 @@
 from transformers.trainer_pt_utils import get_tpu_sampler
 
 
-try:
-    from .utils import label_smoothed_nll_loss
-except ImportError:
-    from utils import label_smoothed_nll_loss
-
-
 logger = logging.get_logger(__name__)
 
 arg_to_scheduler = {
@@ -64,6 +58,17 @@ def __init__(self, config=None, data_args=None, *args, **kwargs):
                 f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
             )
 
+        if self.args.label_smoothing == 0:
+            self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+        else:
+            # dynamically import label_smoothed_nll_loss
+            try:
+                from .utils import label_smoothed_nll_loss
+            except ImportError:
+                from utils import label_smoothed_nll_loss
+
+            self.loss_fn = label_smoothed_nll_loss
+
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.
@@ -135,9 +140,7 @@ def _compute_loss(self, model, inputs, labels):
         if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
             # force training to ignore pad token
             logits = model(**inputs, use_cache=False)[0]
-
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
-            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
+            loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
         else:
             # compute usual loss via models
             loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
@@ -145,9 +148,7 @@ def _compute_loss(self, model, inputs, labels):
             # compute label smoothed loss
             logits = model(**inputs, use_cache=False)[0]
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
-            loss, _ = label_smoothed_nll_loss(
-                lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
-            )
+            loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id)
         return loss, logits
 
     def compute_loss(self, model, inputs):
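
For context, the `label_smoothed_nll_loss` helper that this diff now imports lazily in `__init__` follows the common fairseq-style formulation: the negative log-likelihood of the gold token is interpolated with the summed negative log-probability over the whole vocabulary, weighted by the smoothing factor epsilon. The sketch below is an illustration of that technique under assumed shapes (log-probs of shape `(N, vocab)`, targets of shape `(N,)`), not a verbatim copy of `examples/seq2seq/utils.py`; the function name `label_smoothed_nll_loss_sketch` is hypothetical.

import torch

def label_smoothed_nll_loss_sketch(lprobs, target, epsilon, ignore_index=-100):
    # lprobs: log-probabilities, shape (N, vocab); target: gold token ids, shape (N,)
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)       # per-position NLL of the gold token
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)       # summed NLL over the full vocabulary
    pad_mask = target.eq(ignore_index)                    # zero out padding positions
    nll_loss.masked_fill_(pad_mask, 0.0)
    smooth_loss.masked_fill_(pad_mask, 0.0)
    nll_loss, smooth_loss = nll_loss.sum(), smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    # interpolate the hard-target NLL with the uniform (smoothed) distribution
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss

Caching the callable as `self.loss_fn` once in `__init__` (either `torch.nn.CrossEntropyLoss` or the smoothed variant) keeps `_compute_loss` to a single dispatch instead of re-selecting the loss on every batch.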