diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst index 2859c2186a2f..d4a81f0c84d5 100644 --- a/docs/source/model_doc/auto.rst +++ b/docs/source/model_doc/auto.rst @@ -81,6 +81,13 @@ AutoModelForMultipleChoice :members: +AutoModelForNextSentencePrediction +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.AutoModelForNextSentencePrediction + :members: + + AutoModelForTokenClassification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/examples/lxmert/modeling_frcnn.py b/examples/lxmert/modeling_frcnn.py index 40b0e4bbfb40..a86f68801eff 100644 --- a/examples/lxmert/modeling_frcnn.py +++ b/examples/lxmert/modeling_frcnn.py @@ -1801,7 +1801,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " f"initializing {model.__class__.__name__}: {unexpected_keys}\n" f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " - f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n" + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." ) diff --git a/model_cards/neuralmind/bert-base-portuguese-cased/README.md b/model_cards/neuralmind/bert-base-portuguese-cased/README.md index 375f4268711e..85deb52e3618 100644 --- a/model_cards/neuralmind/bert-base-portuguese-cased/README.md +++ b/model_cards/neuralmind/bert-base-portuguese-cased/README.md @@ -29,7 +29,7 @@ For further information or requests, please go to [BERTimbau repository](https:/ ```python from transformers import AutoTokenizer # Or BertTokenizer -from transformers import AutoModelForPretraining # Or BertForPreTraining for loading pretraining heads +from transformers import AutoModelForPreTraining # Or BertForPreTraining for loading pretraining heads from transformers import AutoModel # or BertModel, for BERT without pretraining heads model = AutoModelForPreTraining.from_pretrained('neuralmind/bert-base-portuguese-cased') diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index d8279e8c5305..fbb6789d4ebc 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -329,6 +329,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, MODEL_FOR_PRETRAINING_MAPPING, MODEL_FOR_QUESTION_ANSWERING_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, @@ -340,6 +341,7 @@ AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForMultipleChoice, + AutoModelForNextSentencePrediction, AutoModelForPreTraining, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 3ec971325075..03212b17b185 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -77,6 +77,7 @@ from .modeling_bert import ( BertForMaskedLM, BertForMultipleChoice, + 
    BertForNextSentencePrediction,
     BertForPreTraining,
     BertForQuestionAnswering,
     BertForSequenceClassification,
@@ -128,6 +129,7 @@
 from .modeling_funnel import (
     FunnelForMaskedLM,
     FunnelForMultipleChoice,
+    FunnelForPreTraining,
     FunnelForQuestionAnswering,
     FunnelForSequenceClassification,
     FunnelForTokenClassification,
@@ -143,12 +145,13 @@
     LongformerForTokenClassification,
     LongformerModel,
 )
-from .modeling_lxmert import LxmertForPreTraining, LxmertModel
+from .modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
 from .modeling_marian import MarianMTModel
 from .modeling_mbart import MBartForConditionalGeneration
 from .modeling_mobilebert import (
     MobileBertForMaskedLM,
     MobileBertForMultipleChoice,
+    MobileBertForNextSentencePrediction,
     MobileBertForPreTraining,
     MobileBertForQuestionAnswering,
     MobileBertForSequenceClassification,
@@ -166,6 +169,7 @@
 from .modeling_reformer import (
     ReformerForMaskedLM,
     ReformerForQuestionAnswering,
+    ReformerForSequenceClassification,
     ReformerModel,
     ReformerModelWithLMHead,
 )
@@ -285,6 +289,7 @@
         (CTRLConfig, CTRLLMHeadModel),
         (ElectraConfig, ElectraForPreTraining),
         (LxmertConfig, LxmertForPreTraining),
+        (FunnelConfig, FunnelForPreTraining),
     ]
 )
@@ -396,6 +401,7 @@
         (DebertaConfig, DebertaForSequenceClassification),
         (GPT2Config, GPT2ForSequenceClassification),
         (OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
+        (ReformerConfig, ReformerForSequenceClassification),
     ]
 )
@@ -417,6 +423,7 @@
         (ElectraConfig, ElectraForQuestionAnswering),
         (ReformerConfig, ReformerForQuestionAnswering),
         (FunnelConfig, FunnelForQuestionAnswering),
+        (LxmertConfig, LxmertForQuestionAnswering),
     ]
 )
@@ -460,6 +467,13 @@
     ]
 )
 
+MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
+    [
+        (BertConfig, BertForNextSentencePrediction),
+        (MobileBertConfig, MobileBertForNextSentencePrediction),
+    ]
+)
+
 AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
     The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
@@ -1519,3 +1533,103 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 ", ".join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
             )
         )
+
+
+class AutoModelForNextSentencePrediction:
+    r"""
+    This is a generic model class that will be instantiated as one of the model classes of the library---with a
+    next sentence prediction head---when created with the
+    :meth:`~transformers.AutoModelForNextSentencePrediction.from_pretrained` class method or the
+    :meth:`~transformers.AutoModelForNextSentencePrediction.from_config` class method.
+
+    This class cannot be instantiated directly using ``__init__()`` (throws an error).
+    """
+
+    def __init__(self):
+        raise EnvironmentError(
+            "AutoModelForNextSentencePrediction is designed to be instantiated "
+            "using the `AutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
+            "`AutoModelForNextSentencePrediction.from_config(config)` methods."
+        )
+
+    @classmethod
+    @replace_list_option_in_docstrings(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
+    def from_config(cls, config):
+        r"""
+        Instantiates one of the model classes of the library---with a next sentence prediction head---from a
+        configuration.
+
+        Note:
+            Loading a model from its configuration file does **not** load the model weights. It only affects the
+            model's configuration. Use :meth:`~transformers.AutoModelForNextSentencePrediction.from_pretrained` to load
+            the model weights.
+
+        Args:
+            config (:class:`~transformers.PretrainedConfig`):
+                The model class to instantiate is selected based on the configuration class:
+
+                List options
+
+        Examples::
+
+            >>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
+            >>> # Download configuration from S3 and cache.
+            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
+            >>> model = AutoModelForNextSentencePrediction.from_config(config)
+        """
+        if type(config) in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
+            return MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)
+
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
+            )
+        )
+
+    @classmethod
+    @replace_list_option_in_docstrings(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
+    @add_start_docstrings(
+        "Instantiate one of the model classes of the library---with a next sentence prediction head---from a "
+        "pretrained model.",
+        AUTO_MODEL_PRETRAINED_DOCSTRING,
+    )
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        r"""
+        Examples::
+
+            >>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
+
+            >>> # Download model and configuration from S3 and cache.
+            >>> model = AutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
+
+            >>> # Update configuration during loading
+            >>> model = AutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
+            >>> model.config.output_attentions
+            True
+
+            >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
+            >>> config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
+            >>> model = AutoModelForNextSentencePrediction.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
+        """
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
+
+        if type(config) in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
+            return MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
+            )
+        )
diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py
index 35b03b73e47d..71d4a23dce6a 100755
--- a/src/transformers/modeling_bert.py
+++ b/src/transformers/modeling_bert.py
@@ -1228,13 +1228,14 @@ def forward(
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        next_sentence_label=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs
     ):
         r"""
-        next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
             Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
             (see ``input_ids`` docstring).
Indices should be in ``[0, 1]``: @@ -1255,10 +1256,18 @@ def forward( >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') - >>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1])) + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) >>> logits = outputs.logits >>> assert logits[0, 0] < logits[0, 1] # next sentence was random """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.bert( @@ -1278,9 +1287,9 @@ def forward( seq_relationship_scores = self.cls(pooled_output) next_sentence_loss = None - if next_sentence_label is not None: + if labels is not None: loss_fct = CrossEntropyLoss() - next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), next_sentence_label.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) if not return_dict: output = (seq_relationship_scores,) + outputs[2:] diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 838ea248f172..442f78ec434f 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -1069,7 +1069,7 @@ def forward( if self.num_labels == 1: # We are doing regression loss_fct = MSELoss() - loss = loss_fct(pooled_logits.view(-1), labels.view(-1)) + loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1)) else: loss_fct = CrossEntropyLoss() loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 6e468623cca5..950dd0da4448 100755 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -1069,7 +1069,7 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): - return module(*inputs, output_attentions) + return module(*inputs, is_global_attn) return custom_forward @@ -1079,7 +1079,6 @@ def custom_forward(*inputs): attention_mask, is_index_masked, is_index_global_attn, - is_global_attn, ) else: layer_outputs = layer_module( diff --git a/src/transformers/modeling_lxmert.py b/src/transformers/modeling_lxmert.py index cbca95d160bb..ca49cf993471 100644 --- a/src/transformers/modeling_lxmert.py +++ b/src/transformers/modeling_lxmert.py @@ -17,6 +17,7 @@ import math import os +import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -1154,16 +1155,17 @@ def forward( visual_attention_mask=None, token_type_ids=None, inputs_embeds=None, - masked_lm_labels=None, + labels=None, obj_labels=None, matched_label=None, ans=None, output_attentions=None, output_hidden_states=None, return_dict=None, + **kwargs, ): r""" - masked_lm_labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`): + labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`): Labels for computing the masked language modeling loss. 
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` @@ -1183,6 +1185,15 @@ def forward( Returns: """ + if "masked_lm_labels" in kwargs: + warnings.warn( + "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("masked_lm_labels") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + device = input_ids.device if input_ids is not None else inputs_embeds.device lxmert_output = self.lxmert( input_ids=input_ids, @@ -1210,13 +1221,13 @@ def forward( total_loss = ( None - if (masked_lm_labels is None and matched_label is None and obj_labels is None and ans is None) + if (labels is None and matched_label is None and obj_labels is None and ans is None) else torch.tensor(0.0, device=device) ) - if masked_lm_labels is not None and self.task_mask_lm: + if labels is not None and self.task_mask_lm: masked_lm_loss = self.loss_fcts["ce"]( lang_prediction_scores.view(-1, self.config.vocab_size), - masked_lm_labels.view(-1), + labels.view(-1), ) total_loss += masked_lm_loss if matched_label is not None and self.task_matched: @@ -1391,6 +1402,7 @@ def forward( Returns: """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict lxmert_output = self.lxmert( input_ids=input_ids, diff --git a/src/transformers/modeling_mobilebert.py b/src/transformers/modeling_mobilebert.py index 99110386c873..e890c30eb22b 100644 --- a/src/transformers/modeling_mobilebert.py +++ b/src/transformers/modeling_mobilebert.py @@ -1194,13 +1194,14 @@ def forward( position_ids=None, head_mask=None, inputs_embeds=None, - next_sentence_label=None, + labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, + **kwargs, ): r""" - next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) Indices should be in ``[0, 1]``. @@ -1221,10 +1222,18 @@ def forward( >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." 
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') - >>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1])) + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) >>> loss = outputs.loss >>> logits = outputs.logits """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.mobilebert( @@ -1243,9 +1252,9 @@ def forward( seq_relationship_score = self.cls(pooled_output) next_sentence_loss = None - if next_sentence_label is not None: + if labels is not None: loss_fct = CrossEntropyLoss() - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), labels.view(-1)) if not return_dict: output = (seq_relationship_score,) + outputs[2:] diff --git a/src/transformers/modeling_openai.py b/src/transformers/modeling_openai.py index 5bfb2b682ad2..011fc5eb8cc8 100644 --- a/src/transformers/modeling_openai.py +++ b/src/transformers/modeling_openai.py @@ -824,7 +824,7 @@ def forward( if self.num_labels == 1: # We are doing regression loss_fct = MSELoss() - loss = loss_fct(pooled_logits.view(-1), labels.view(-1)) + loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1)) else: loss_fct = CrossEntropyLoss() loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index ffd50e47c95f..adcc19c61be2 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -221,7 +221,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a f"Some weights of the PyTorch model were not used when " f"initializing the TF 2.0 model {tf_model.__class__.__name__}: {unexpected_keys}\n" f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model trained on another task " - f"or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPretraining model).\n" + f"or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n" f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect " f"to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model)." ) @@ -375,7 +375,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F f"Some weights of the TF 2.0 model were not used when " f"initializing the PyTorch model {pt_model.__class__.__name__}: {unexpected_keys}\n" f"- This IS expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model trained on another task " - f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPretraining model).\n" + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n" f"- This IS NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect " f"to be exactly identical (e.g. 
initializing a BertForSequenceClassification model from a TFBertForSequenceClassification model)." ) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index ab3523b8724e..8c46bd59dcb9 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -730,7 +730,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when " f"initializing {model.__class__.__name__}: {unexpected_keys}\n" f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " - f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n" + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." ) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b7a87f99a179..6b5e7bab1bd0 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1047,7 +1047,7 @@ def load(module: nn.Module, prefix=""): f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " f"initializing {model.__class__.__name__}: {unexpected_keys}\n" f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " - f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n" + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." 
) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index b37290388126..ce4fcfc7ce01 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -256,6 +256,9 @@ def load_tf_weights_in_albert(*args, **kwargs): MODEL_FOR_MULTIPLE_CHOICE_MAPPING = None +MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = None + + MODEL_FOR_PRETRAINING_MAPPING = None @@ -313,6 +316,15 @@ def from_pretrained(self, *args, **kwargs): requires_pytorch(self) +class AutoModelForNextSentencePrediction: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + class AutoModelForPreTraining: def __init__(self, *args, **kwargs): requires_pytorch(self) diff --git a/tests/test_modeling_albert.py b/tests/test_modeling_albert.py index 9040e1a5484c..f3f2459b16b1 100644 --- a/tests/test_modeling_albert.py +++ b/tests/test_modeling_albert.py @@ -24,7 +24,10 @@ if is_torch_available(): + import torch + from transformers import ( + MODEL_FOR_PRETRAINING_MAPPING, AlbertConfig, AlbertForMaskedLM, AlbertForMultipleChoice, @@ -227,6 +230,20 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase): else () ) + # special case for ForPreTraining model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + inputs_dict["sentence_order_label"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + return inputs_dict + def setUp(self): self.model_tester = AlbertModelTester(self) self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=37) diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index c18cf5a7308b..1cc296714b5e 100755 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -25,7 +25,10 @@ if is_torch_available(): + import torch + from transformers import ( + MODEL_FOR_PRETRAINING_MAPPING, BertConfig, BertForMaskedLM, BertForMultipleChoice, @@ -268,7 +271,7 @@ def create_and_check_for_next_sequence_prediction( input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, - next_sentence_label=sequence_labels, + labels=sequence_labels, ) self.parent.assertEqual(result.logits.shape, (self.batch_size, 2)) @@ -377,6 +380,20 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ) all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else () + # special case for ForPreTraining model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + inputs_dict["next_sentence_label"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + return inputs_dict + def setUp(self): self.model_tester = BertModelTester(self) self.config_tester = ConfigTester(self, config_class=BertConfig, 
hidden_size=37) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 53f45a26f388..7435a6a1f5a5 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -35,10 +35,12 @@ MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, MODEL_FOR_QUESTION_ANSWERING_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + MODEL_MAPPING, AdaptiveEmbedding, BertConfig, BertModel, @@ -88,7 +90,10 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): inputs_dict["end_positions"] = torch.zeros( self.model_tester.batch_size, dtype=torch.long, device=torch_device ) - elif model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(): + elif model_class in [ + *MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(), + *MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(), + ]: inputs_dict["labels"] = torch.zeros( self.model_tester.batch_size, dtype=torch.long, device=torch_device ) @@ -204,6 +209,41 @@ def test_forward_signature(self): expected_arg_names = ["input_ids"] self.assertListEqual(arg_names[:1], expected_arg_names) + def test_training(self): + if not self.model_tester.is_training: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class in MODEL_MAPPING.values(): + continue + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + + def test_training_gradient_checkpointing(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if not self.model_tester.is_training or not hasattr(config, "gradient_checkpointing"): + return + + config.gradient_checkpointing = True + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class in MODEL_MAPPING.values(): + continue + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True diff --git a/tests/test_modeling_dpr.py b/tests/test_modeling_dpr.py index ad6b860288bc..07a21e00bf2c 100644 --- a/tests/test_modeling_dpr.py +++ b/tests/test_modeling_dpr.py @@ -38,7 +38,7 @@ def __init__( parent, batch_size=13, seq_length=7, - is_training=True, + is_training=False, use_input_mask=True, use_token_type_ids=True, use_labels=True, diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py index 29bc782f937f..340fbcd18023 100644 --- a/tests/test_modeling_electra.py +++ b/tests/test_modeling_electra.py @@ -24,7 +24,10 @@ if is_torch_available(): + import torch + from transformers import ( + MODEL_FOR_PRETRAINING_MAPPING, ElectraConfig, ElectraForMaskedLM, ElectraForMultipleChoice, @@ -285,6 +288,17 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): else () ) + # special case for ForPreTraining model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if 
model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + return inputs_dict + def setUp(self): self.model_tester = ElectraModelTester(self) self.config_tester = ConfigTester(self, config_class=ElectraConfig, hidden_size=37) diff --git a/tests/test_modeling_flaubert.py b/tests/test_modeling_flaubert.py index 6694d9c912e7..b5617a059147 100644 --- a/tests/test_modeling_flaubert.py +++ b/tests/test_modeling_flaubert.py @@ -24,6 +24,8 @@ if is_torch_available(): + import torch + from transformers import ( FlaubertConfig, FlaubertForMultipleChoice, @@ -343,6 +345,21 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): else () ) + # Flaubert has 2 QA models -> need to manually set the correct labels for one of them here + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "FlaubertForQuestionAnswering": + inputs_dict["start_positions"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + inputs_dict["end_positions"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + + return inputs_dict + def setUp(self): self.model_tester = FlaubertModelTester(self) self.config_tester = ConfigTester(self, config_class=FlaubertConfig, emb_dim=37) diff --git a/tests/test_modeling_funnel.py b/tests/test_modeling_funnel.py index 1b59cc93fb20..f3fd12e9378c 100644 --- a/tests/test_modeling_funnel.py +++ b/tests/test_modeling_funnel.py @@ -27,6 +27,7 @@ import torch from transformers import ( + MODEL_FOR_PRETRAINING_MAPPING, FunnelBaseModel, FunnelConfig, FunnelForMaskedLM, @@ -360,6 +361,17 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase): else () ) + # special case for ForPreTraining model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + return inputs_dict + def setUp(self): self.model_tester = FunnelModelTester(self) self.config_tester = ConfigTester(self, config_class=FunnelConfig) @@ -415,6 +427,21 @@ def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) + # overwrite from test_modeling_common + def test_training(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class.__name__ == "FunnelBaseModel": + continue + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + @require_torch @require_sentencepiece diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 97d5bec376e3..aa6133d35c27 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -388,6 +388,29 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, 
unittest.TestCase): all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () test_missing_keys = False + # special case for DoubleHeads model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "GPT2DoubleHeadsModel": + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.num_choices, self.model_tester.seq_length), + dtype=torch.long, + device=torch_device, + ) + inputs_dict["input_ids"] = inputs_dict["labels"] + inputs_dict["token_type_ids"] = inputs_dict["labels"] + inputs_dict["mc_token_ids"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.num_choices), + dtype=torch.long, + device=torch_device, + ) + inputs_dict["mc_labels"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + return inputs_dict + def setUp(self): self.model_tester = GPT2ModelTester(self) self.config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37) diff --git a/tests/test_modeling_lxmert.py b/tests/test_modeling_lxmert.py index 56b68b92e448..e335603d71ab 100644 --- a/tests/test_modeling_lxmert.py +++ b/tests/test_modeling_lxmert.py @@ -14,6 +14,7 @@ # limitations under the License. +import copy import unittest from transformers import is_torch_available @@ -26,7 +27,14 @@ if is_torch_available(): import torch - from transformers import LxmertConfig, LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel + from transformers import ( + MODEL_FOR_PRETRAINING_MAPPING, + MODEL_FOR_QUESTION_ANSWERING_MAPPING, + LxmertConfig, + LxmertForPreTraining, + LxmertForQuestionAnswering, + LxmertModel, + ) from transformers.modeling_lxmert import LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST @@ -533,6 +541,22 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase): test_pruning = False test_torchscript = False + # overwrite function because qa models takes different input label shape + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = copy.deepcopy(inputs_dict) + + if return_labels: + if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values(): + inputs_dict["labels"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + elif model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): + # special case for models like BERT that use multi-loss training for PreTraining + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + return inputs_dict + def setUp(self): self.model_tester = LxmertModelTester(self) self.config_tester = ConfigTester(self, config_class=LxmertConfig, hidden_size=37) diff --git a/tests/test_modeling_mobilebert.py b/tests/test_modeling_mobilebert.py index c586858f102c..e1e3ad82d078 100644 --- a/tests/test_modeling_mobilebert.py +++ b/tests/test_modeling_mobilebert.py @@ -27,6 +27,7 @@ import torch from transformers import ( + MODEL_FOR_PRETRAINING_MAPPING, MobileBertConfig, MobileBertForMaskedLM, MobileBertForMultipleChoice, @@ -220,7 +221,7 @@ def create_and_check_mobilebert_for_next_sequence_prediction( input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, - next_sentence_label=sequence_labels, + labels=sequence_labels, ) self.parent.assertEqual(result.logits.shape, (self.batch_size, 2)) @@ -327,6 +328,20 @@ 
class MobileBertModelTest(ModelTesterMixin, unittest.TestCase): else () ) + # special case for ForPreTraining model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + inputs_dict["next_sentence_label"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + return inputs_dict + def setUp(self): self.model_tester = MobileBertModelTester(self) self.config_tester = ConfigTester(self, config_class=MobileBertConfig, hidden_size=37) diff --git a/tests/test_modeling_openai.py b/tests/test_modeling_openai.py index eae027e7a0ac..75858a05498b 100644 --- a/tests/test_modeling_openai.py +++ b/tests/test_modeling_openai.py @@ -182,6 +182,29 @@ class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC (OpenAIGPTLMHeadModel,) if is_torch_available() else () ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly + # special case for DoubleHeads model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "OpenAIGPTDoubleHeadsModel": + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.num_choices, self.model_tester.seq_length), + dtype=torch.long, + device=torch_device, + ) + inputs_dict["input_ids"] = inputs_dict["labels"] + inputs_dict["token_type_ids"] = inputs_dict["labels"] + inputs_dict["mc_token_ids"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.num_choices), + dtype=torch.long, + device=torch_device, + ) + inputs_dict["mc_labels"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + return inputs_dict + def setUp(self): self.model_tester = OpenAIGPTModelTester(self) self.config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37) diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py index 0a967ce287a7..e8016c4282fb 100644 --- a/tests/test_modeling_prophetnet.py +++ b/tests/test_modeling_prophetnet.py @@ -1038,7 +1038,7 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix is_encoder_decoder = False def setUp(self): - self.model_tester = ProphetNetStandaloneDecoderModelTester(self) + self.model_tester = ProphetNetStandaloneDecoderModelTester(self, is_training=False) self.config_tester = ConfigTester(self, config_class=ProphetNetConfig) def test_config(self): @@ -1063,7 +1063,7 @@ class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase): is_encoder_decoder = False def setUp(self): - self.model_tester = ProphetNetStandaloneEncoderModelTester(self) + self.model_tester = ProphetNetStandaloneEncoderModelTester(self, is_training=False) self.config_tester = ConfigTester(self, config_class=ProphetNetConfig) def test_config(self): diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index ad1b8aed7e03..7f6478e3a7ec 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -42,7 +42,7 @@ def __init__( self.mem_len = 30 self.key_length = self.seq_length 
+ self.mem_len self.clamp_len = 15 - self.is_training = True + self.is_training = False self.use_labels = True self.vocab_size = 99 self.cutoffs = [10, 50, 80] diff --git a/tests/test_modeling_xlm.py b/tests/test_modeling_xlm.py index 14c3236ef974..852a6a4e0544 100644 --- a/tests/test_modeling_xlm.py +++ b/tests/test_modeling_xlm.py @@ -351,6 +351,21 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): (XLMWithLMHeadModel,) if is_torch_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable + # XLM has 2 QA models -> need to manually set the correct labels for one of them here + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "XLMForQuestionAnswering": + inputs_dict["start_positions"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + inputs_dict["end_positions"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + + return inputs_dict + def setUp(self): self.model_tester = XLMModelTester(self) self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37) diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index c154874a1ebe..9bd81f9b9e2e 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -499,6 +499,21 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) ) # TODO (PVP): Check other models whether language generation is also applicable test_pruning = False + # XLNet has 2 QA models -> need to manually set the correct labels for one of them here + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "XLNetForQuestionAnswering": + inputs_dict["start_positions"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + inputs_dict["end_positions"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + + return inputs_dict + def setUp(self): self.model_tester = XLNetModelTester(self) self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
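The docstring examples touched by this patch can be combined into one short, runnable snippet. The sketch below is illustrative only and not part of the patch; it assumes the `AutoModelForNextSentencePrediction` export and the `labels` argument introduced above, and reuses the checkpoint name and sentence pair from the patch's own docstring examples.

import torch
from transformers import AutoTokenizer, AutoModelForNextSentencePrediction

# Load a BERT checkpoint through the new auto class added in this patch.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")

prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

# `labels` replaces the deprecated `next_sentence_label` argument; the old name
# still works but now emits a FutureWarning. Label 1 means sentence B is random.
outputs = model(**encoding, labels=torch.LongTensor([1]), return_dict=True)
print(outputs.loss, outputs.logits)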