[Tests] Add Common Test for Training + Fix a couple of bugs #8415
**src/transformers/modeling_auto.py**

```diff
@@ -77,6 +77,7 @@
 from .modeling_bert import (
     BertForMaskedLM,
     BertForMultipleChoice,
+    BertForNextSentencePrediction,
     BertForPreTraining,
     BertForQuestionAnswering,
     BertForSequenceClassification,
```
```diff
@@ -128,6 +129,7 @@
 from .modeling_funnel import (
     FunnelForMaskedLM,
     FunnelForMultipleChoice,
+    FunnelForPreTraining,
     FunnelForQuestionAnswering,
     FunnelForSequenceClassification,
     FunnelForTokenClassification,
```
```diff
@@ -143,12 +145,13 @@
     LongformerForTokenClassification,
     LongformerModel,
 )
-from .modeling_lxmert import LxmertForPreTraining, LxmertModel
+from .modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
 from .modeling_marian import MarianMTModel
 from .modeling_mbart import MBartForConditionalGeneration
 from .modeling_mobilebert import (
     MobileBertForMaskedLM,
     MobileBertForMultipleChoice,
+    MobileBertForNextSentencePrediction,
     MobileBertForPreTraining,
     MobileBertForQuestionAnswering,
     MobileBertForSequenceClassification,
```
```diff
@@ -166,6 +169,7 @@
 from .modeling_reformer import (
     ReformerForMaskedLM,
     ReformerForQuestionAnswering,
+    ReformerForSequenceClassification,
     ReformerModel,
     ReformerModelWithLMHead,
 )
```
```diff
@@ -285,6 +289,7 @@
         (CTRLConfig, CTRLLMHeadModel),
         (ElectraConfig, ElectraForPreTraining),
         (LxmertConfig, LxmertForPreTraining),
+        (FunnelConfig, FunnelForPreTraining),
     ]
 )
```

**Contributor (author):** was missing

**Collaborator:** Good catch!
```diff
@@ -396,6 +401,7 @@
         (DebertaConfig, DebertaForSequenceClassification),
         (GPT2Config, GPT2ForSequenceClassification),
         (OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
+        (ReformerConfig, ReformerForSequenceClassification),
     ]
 )
```

**Contributor (author):** was missing

**Collaborator:** Another good catch!
```diff
@@ -417,6 +423,7 @@
         (ElectraConfig, ElectraForQuestionAnswering),
         (ReformerConfig, ReformerForQuestionAnswering),
         (FunnelConfig, FunnelForQuestionAnswering),
+        (LxmertConfig, LxmertForQuestionAnswering),
     ]
 )
```

**Contributor (author):** was missing

**Collaborator:** We definitely need some kind of script to check those automatically ;-)

**Member:** Yes, we do!
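The kind of automated check the reviewers have in mind could look roughly like the sketch below. This is hypothetical and not part of this PR; it assumes the v3-era module path `transformers.modeling_auto` and relies on every model class exposing a `config_class` attribute.

```python
# Hypothetical consistency check (not part of this PR): verify that every
# entry in an auto-model mapping is exported and points back to its config.
import transformers
from transformers.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING


def check_mapping(mapping):
    for config_cls, model_cls in mapping.items():
        # Every mapped model should be importable from the top-level package.
        assert hasattr(transformers, model_cls.__name__), model_cls.__name__
        # The model's declared config class should match the mapping key.
        assert model_cls.config_class is config_cls, (config_cls, model_cls)


check_mapping(MODEL_FOR_QUESTION_ANSWERING_MAPPING)
```

A fuller version would also scan the package for task-specific classes that are absent from their mapping, which is exactly the failure mode fixed in these hunks.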
```diff
@@ -460,6 +467,13 @@
     ]
 )

+MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
+    [
+        (BertConfig, BertForNextSentencePrediction),
+        (MobileBertConfig, MobileBertForNextSentencePrediction),
+    ]
+)
+
 AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
     The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
```
```diff
@@ -1519,3 +1533,103 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 ", ".join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
             )
         )
+
+
+class AutoModelForNextSentencePrediction:
+    r"""
+    This is a generic model class that will be instantiated as one of the model classes of the library---with a
+    next sentence prediction head---when created with the
+    :meth:`~transformers.AutoModelForNextSentencePrediction.from_pretrained` class method or the
+    :meth:`~transformers.AutoModelForNextSentencePrediction.from_config` class method.
+
+    This class cannot be instantiated directly using ``__init__()`` (throws an error).
+    """
+
+    def __init__(self):
+        raise EnvironmentError(
+            "AutoModelForNextSentencePrediction is designed to be instantiated "
+            "using the `AutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
+            "`AutoModelForNextSentencePrediction.from_config(config)` methods."
+        )
+
+    @classmethod
+    @replace_list_option_in_docstrings(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
+    def from_config(cls, config):
+        r"""
+        Instantiates one of the model classes of the library---with a next sentence prediction head---from a
+        configuration.
+
+        Note:
+            Loading a model from its configuration file does **not** load the model weights. It only affects the
+            model's configuration. Use :meth:`~transformers.AutoModelForNextSentencePrediction.from_pretrained` to
+            load the model weights.
+
+        Args:
+            config (:class:`~transformers.PretrainedConfig`):
+                The model class to instantiate is selected based on the configuration class:
+
+                List options
+
+        Examples::
+
+            >>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
+            >>> # Download configuration from S3 and cache.
+            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
+            >>> model = AutoModelForNextSentencePrediction.from_config(config)
+        """
+        if type(config) in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
+            return MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)
+
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
+            )
+        )
+
+    @classmethod
+    @replace_list_option_in_docstrings(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
+    @add_start_docstrings(
+        "Instantiate one of the model classes of the library---with a next sentence prediction head---from a "
+        "pretrained model.",
+        AUTO_MODEL_PRETRAINED_DOCSTRING,
+    )
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        r"""
+        Examples::
+
+            >>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
+
+            >>> # Download model and configuration from S3 and cache.
+            >>> model = AutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
+
+            >>> # Update configuration during loading
+            >>> model = AutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
+            >>> model.config.output_attentions
+            True
+
+            >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
+            >>> config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
+            >>> model = AutoModelForNextSentencePrediction.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
+        """
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
+
+        if type(config) in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
+            return MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
+            )
+        )
```
**src/transformers/modeling_bert.py**
```diff
@@ -1228,13 +1228,14 @@ def forward(
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        next_sentence_label=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs
     ):
         r"""
-        next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
             Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
             (see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
```

**Contributor (author):** If a model has only one loss, as is the case here, I think we should force the name to be `labels`.

**Collaborator:** I agree with that approach. This might require special new dataloaders to change the labels name, but this is easy enough to do.
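Regarding the dataloader point above, the rename shim a dataset or collator could apply might look like this minimal sketch (hypothetical helper, not part of this PR):

```python
# Hypothetical collator wrapper: lets datasets that still emit
# `next_sentence_label` feed models that now expect `labels`.
def rename_nsp_column(batch: dict) -> dict:
    if "next_sentence_label" in batch:
        batch["labels"] = batch.pop("next_sentence_label")
    return batch
```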
```diff
@@ -1255,10 +1256,18 @@ def forward(
         >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
         >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')

-        >>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
+        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
         >>> logits = outputs.logits
         >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
         """
+
+        if "next_sentence_label" in kwargs:
+            warnings.warn(
+                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
+                FutureWarning,
+            )
+            labels = kwargs.pop("next_sentence_label")
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = self.bert(
```
```diff
@@ -1278,9 +1287,9 @@ def forward(
         seq_relationship_scores = self.cls(pooled_output)

         next_sentence_loss = None
-        if next_sentence_label is not None:
+        if labels is not None:
             loss_fct = CrossEntropyLoss()
-            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), next_sentence_label.view(-1))
+            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

         if not return_dict:
             output = (seq_relationship_scores,) + outputs[2:]
```
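Taken together, the three hunks above keep old callers working while steering them to the new keyword. A sketch of both call styles, using the same checkpoint as the doctest above (with an explicit `return_dict=True`, since the default varied across versions):

```python
import warnings

import torch
from transformers import BertForNextSentencePrediction, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
encoding = tokenizer("A first sentence.", "And a second one.", return_tensors="pt")

# Deprecated spelling: still computes the loss, but now emits a FutureWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = model(**encoding, next_sentence_label=torch.LongTensor([0]), return_dict=True)
assert any(issubclass(w.category, FutureWarning) for w in caught)

# New spelling: identical result, no warning.
out = model(**encoding, labels=torch.LongTensor([0]), return_dict=True)
print(out.loss, out.logits)
```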
---
```diff
@@ -1069,7 +1069,7 @@ def forward(
         if self.num_labels == 1:
             # We are doing regression
             loss_fct = MSELoss()
-            loss = loss_fct(pooled_logits.view(-1), labels.view(-1))
+            loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
         else:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
```

**Member:** 👍
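The cast matters because `MSELoss` needs its input and target to share a dtype: regression labels often arrive as `float64` (e.g. from numpy or pandas) while the model's logits are `float32`. A minimal repro of the mismatch (the exact failure point can vary by PyTorch version, so the whole forward/backward is guarded):

```python
import torch
from torch.nn import MSELoss

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
inputs = torch.randn(2, 4)
# Regression labels frequently come in as float64 from a dataloader.
labels = torch.tensor([0.5, 1.5], dtype=torch.float64)

loss_fct = MSELoss()
try:
    loss = loss_fct(model(inputs).view(-1), labels)
    loss.backward()  # the dtype mismatch surfaces as a RuntimeError
except RuntimeError as err:
    print("without the cast:", err)

# The fix from the diff: cast the labels to the model's dtype first.
loss = loss_fct(model(inputs).view(-1), labels.to(torch.float32))
loss.backward()
```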
**src/transformers/modeling_longformer.py**
```diff
@@ -1069,7 +1069,7 @@ def forward(

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
+                        return module(*inputs, is_global_attn)

                     return custom_forward
```

**Contributor (author):** This fixes the longformer bug.

```diff
@@ -1079,7 +1079,6 @@ def custom_forward(*inputs):
                     attention_mask,
                     is_index_masked,
                     is_index_global_attn,
-                    is_global_attn,
                 )
             else:
                 layer_outputs = layer_module(
```
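For context, this is the usual pattern for gradient-checkpointing a layer that takes a non-tensor flag: `torch.utils.checkpoint.checkpoint` should only receive tensors positionally, so a Python bool like `is_global_attn` gets baked into the closure instead of passed through. A minimal sketch, with `Layer` as a stand-in for the Longformer layer:

```python
import torch
from torch.utils.checkpoint import checkpoint


class Layer(torch.nn.Module):
    # Stand-in for a Longformer layer: one tensor input plus a bool flag.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, hidden_states, is_global_attn):
        out = self.linear(hidden_states)
        return out * 2 if is_global_attn else out


layer = Layer()
hidden_states = torch.randn(2, 8, requires_grad=True)
is_global_attn = True  # Python bool, not a tensor

# Bake the flag into a closure so checkpoint() only sees tensor arguments.
# (Newer PyTorch versions may also want an explicit use_reentrant argument.)
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, is_global_attn)
    return custom_forward


out = checkpoint(create_custom_forward(layer), hidden_states)
out.sum().backward()
```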
A final thread discusses whether Albert's `sentence_order_label` should get the same treatment:

- We only have two model types in the library that have a multi-loss model: 1) some `ForPreTraining` models and 2) the GPT `DoubleHeadModels`. I think it's cleaner if we force the loss names to be consistent here -> `sentence_order_label` is renamed to `next_sentence_label`, as is done for `Bert`. The old argument is deprecated.
- Mmmm, this one has a different name than `next_sentence_label` because it is a different objective, if my memory serves me right. So having a different name here does not seem like a bad idea.
- I think the SOP objective can still be represented by the `next_sentence_label`, as it can indicate whether the next sequence follows or precedes the first sequence. However, `sentence_order_label` is a good name as it is, so why are we changing this? For consistency?
- Agree with you! Albert's pretraining is slightly different to BERT's, so different names make sense here. Reverting the renaming.
- Thought it's the exact same objective, but it's not.
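For readers unfamiliar with the two objectives being debated: BERT's NSP labels whether segment B truly follows segment A (1 = a random sentence), while ALBERT's SOP labels whether a genuine consecutive pair has been swapped (1 = swapped). A toy sketch of how the training pairs differ (illustrative helper names, not library code):

```python
import random


def make_nsp_pair(sent_a, sent_b, corpus):
    """BERT-style NSP: the negative class replaces B with a random sentence."""
    if random.random() < 0.5:
        return sent_a, sent_b, 0             # 0 = B really follows A
    return sent_a, random.choice(corpus), 1  # 1 = B is random


def make_sop_pair(sent_a, sent_b):
    """ALBERT-style SOP: the negative class swaps the same two sentences."""
    if random.random() < 0.5:
        return sent_a, sent_b, 0  # 0 = correct order
    return sent_b, sent_a, 1      # 1 = swapped
```

Since the SOP negative is a swapped genuine pair rather than a random sentence, the two labels are not interchangeable, which is the conclusion the thread reaches.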