84 changes: 79 additions & 5 deletions src/transformers/modeling_tf_bert.py
@@ -89,6 +89,38 @@
]


class TFBertPreTrainingLoss:
"""
Loss function suitable for BERT-like pre-training, that is, the task of pretraining a language model by combining
NSP + MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss
computation.
"""

def compute_loss(self, labels, logits):
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
# make sure only labels that are not equal to -100
# are taken into account as loss
masked_lm_active_loss = tf.not_equal(tf.reshape(labels["labels"], (-1,)), -100)
masked_lm_reduced_logits = tf.boolean_mask(
tf.reshape(logits[0], (-1, shape_list(logits[0])[2])),
masked_lm_active_loss,
)
masked_lm_labels = tf.boolean_mask(tf.reshape(labels["labels"], (-1,)), masked_lm_active_loss)
next_sentence_active_loss = tf.not_equal(tf.reshape(labels["next_sentence_label"], (-1,)), -100)
next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits[1], (-1, 2)), next_sentence_active_loss)
next_sentence_label = tf.boolean_mask(
tf.reshape(labels["next_sentence_label"], (-1,)), mask=next_sentence_active_loss
)
masked_lm_loss = loss_fn(masked_lm_labels, masked_lm_reduced_logits)
next_sentence_loss = loss_fn(next_sentence_label, next_sentence_reduced_logits)
masked_lm_loss = tf.reshape(masked_lm_loss, (-1, shape_list(next_sentence_loss)[0]))
masked_lm_loss = tf.reduce_mean(masked_lm_loss, 0)

return masked_lm_loss + next_sentence_loss
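
Not part of the diff: a minimal, self-contained sketch of the masking behaviour described in the docstring above, using toy tensors and plain TensorFlow ops instead of the module's `shape_list` helper. All shapes and values below are illustrative assumptions, not the library's API.

```python
import tensorflow as tf

# Toy batch: 2 sequences of 4 tokens, vocabulary of 10, 2 NSP classes.
mlm_logits = tf.random.normal((2, 4, 10))
nsp_logits = tf.random.normal((2, 2))

# -100 marks positions that must not contribute to the masked-LM loss.
mlm_labels = tf.constant([[3, -100, 7, -100],
                          [-100, 1, -100, -100]])
nsp_labels = tf.constant([0, 1])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)

# Keep only positions whose label is not -100, as compute_loss does above.
flat_labels = tf.reshape(mlm_labels, (-1,))
active = tf.not_equal(flat_labels, -100)
active_logits = tf.boolean_mask(tf.reshape(mlm_logits, (-1, 10)), active)
active_labels = tf.boolean_mask(flat_labels, active)

mlm_loss = tf.reduce_mean(loss_fn(active_labels, active_logits))
nsp_loss = tf.reduce_mean(loss_fn(nsp_labels, nsp_logits))
print(float(mlm_loss + nsp_loss))  # combined pre-training loss on the toy batch
```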


class TFBertEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""

@@ -688,6 +720,7 @@ class TFBertForPreTrainingOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
prediction_logits: tf.Tensor = None
seq_relationship_logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
@@ -814,7 +847,7 @@ def call(self, inputs, **kwargs):
""",
BERT_START_DOCSTRING,
)
class TFBertForPreTraining(TFBertPreTrainedModel):
class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

@@ -827,7 +860,21 @@ def get_output_embeddings(self):

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(self, inputs, **kwargs):
def call(
self,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
next_sentence_label=None,
training=False,
):
r"""
Return:

@@ -843,17 +890,44 @@ def call(self, inputs, **kwargs):
>>> prediction_scores, seq_relationship_scores = outputs[:2]

"""
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.bert.return_dict
outputs = self.bert(inputs, **kwargs)

if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
next_sentence_label = inputs[10] if len(inputs) > 10 else next_sentence_label
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
Comment on lines -846 to +902

[Reviewer (Member)] Just realized that this seems to be an issue in all TF models: if return_dict is defined in the inputs (they're either a tuple, a list or a dict), the value of return_dict won't be used here?

[Contributor (Author)] Yep! But it will be solved with the new input parsing. A fix will be in a PR that will arrive soon.
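
Not part of the diff: a simplified, hypothetical stand-in for the models' `call()` pattern that illustrates the point above. The `call` function and `config_default` below are made-up names; only the `return_dict = return_dict if ...` line mirrors the real code.

```python
def call(inputs=None, return_dict=None, config_default=True):
    # Simplified mirror of the pattern used in the TF models' call() methods.
    return_dict = return_dict if return_dict is not None else config_default
    return return_dict

# Passing return_dict as a keyword argument behaves as expected:
print(call(inputs={"input_ids": [0, 1, 2]}, return_dict=False))  # False

# ...but a return_dict packed inside the inputs container is never unpacked
# by this line, so the config default silently wins (the reviewer's concern):
print(call(inputs={"input_ids": [0, 1, 2], "return_dict": False}))  # True
```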


outputs = self.bert(
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output, pooled_output = outputs[:2]
prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
prediction_scores = self.mlm(sequence_output, training=training)
seq_relationship_score = self.nsp(pooled_output)
total_loss = None

if labels is not None and next_sentence_label is not None:
d_labels = {"labels": labels}
d_labels["next_sentence_label"] = next_sentence_label
total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))

if not return_dict:
return (prediction_scores, seq_relationship_score) + outputs[2:]

return TFBertForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
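Not part of the diff: a hedged usage sketch of the new `labels` / `next_sentence_label` arguments, assuming the `bert-base-uncased` checkpoint is available. For simplicity the input ids are reused as MLM labels so the loss is well defined; real pre-training would put -100 on every unmasked position and the original token ids on the masked ones.

```python
import tensorflow as tf
from transformers import BertTokenizer, TFBertForPreTraining

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertForPreTraining.from_pretrained("bert-base-uncased")

encoding = tokenizer("The capital of France is Paris.", return_tensors="tf")

outputs = model(
    encoding["input_ids"],
    attention_mask=encoding["attention_mask"],
    token_type_ids=encoding["token_type_ids"],
    labels=encoding["input_ids"],          # illustrative: supervise every position
    next_sentence_label=tf.constant([0]),  # arbitrary NSP label for the sketch
    return_dict=True,
)
print(outputs.loss)  # combined MLM + NSP loss from TFBertPreTrainingLoss
```
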
11 changes: 11 additions & 0 deletions tests/test_modeling_tf_bert.py
@@ -26,6 +26,7 @@
if is_tf_available():
import tensorflow as tf

from transformers import TF_MODEL_FOR_PRETRAINING_MAPPING
from transformers.modeling_tf_bert import (
TFBertForMaskedLM,
TFBertForMultipleChoice,
@@ -274,6 +275,16 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)

# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

if return_labels:
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)

return inputs_dict

def setUp(self):
self.model_tester = TFBertModelTester(self)
self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
27 changes: 19 additions & 8 deletions tests/test_modeling_tf_common.py
@@ -36,6 +36,7 @@
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
TF_MODEL_FOR_PRETRAINING_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
@@ -102,6 +103,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> d
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
*TF_MODEL_FOR_MASKED_LM_MAPPING.values(),
*TF_MODEL_FOR_PRETRAINING_MAPPING.values(),
*TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
]:
inputs_dict["labels"] = tf.zeros(
@@ -834,7 +836,9 @@ def test_loss_computation(self):
if getattr(model, "compute_loss", None):
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]]
added_label = prepared_for_class[
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
]
loss_size = tf.size(added_label)

if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
@@ -859,23 +863,30 @@

# Get keys that were added with the _prepare_for_class function
label_keys = prepared_for_class.keys() - inputs_dict.keys()
signature = inspect.getfullargspec(model.call)[0]
signature = inspect.signature(model.call).parameters
signature_names = list(signature.keys())

# Create a dictionary holding the location of the tensors in the tuple
tuple_index_mapping = {1: "input_ids"}
tuple_index_mapping = {0: "input_ids"}
for label_key in label_keys:
label_key_index = signature.index(label_key)
label_key_index = signature_names.index(label_key)
tuple_index_mapping[label_key_index] = label_key
sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
# Initialize a list with their default values, update the values and convert to a tuple
list_input = []

for name in signature_names:
if name != "kwargs":
list_input.append(signature[name].default)

# Initialize a list with None, update the values and convert to a tuple
list_input = [None] * sorted_tuple_index_mapping[-1][0]
for index, value in sorted_tuple_index_mapping:
list_input[index - 1] = prepared_for_class[value]
list_input[index] = prepared_for_class[value]

tuple_input = tuple(list_input)

# Send to model
loss = model(tuple_input)[0]
loss = model(tuple_input[:-1])[0]

self.assertEqual(loss.shape, [loss_size])
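
Not part of the diff: a self-contained toy showing why the tuple-building logic changed. `ToyModel` and its `call()` signature are hypothetical; the point is that `inspect.signature()` on a bound method drops `self`, so `input_ids` now maps to index 0 (not 1), the remaining slots are filled with the signature defaults, and the trailing `training` default is sliced off with `tuple_input[:-1]`.

```python
import inspect

class ToyModel:
    # Hypothetical signature mirroring the shape of a TF model's call().
    def call(self, inputs=None, attention_mask=None, labels=None, training=False, **kwargs):
        return inputs, labels

model = ToyModel()

signature = inspect.signature(model.call).parameters   # no 'self' entry
signature_names = list(signature.keys())                # 'kwargs' comes last

prepared_for_class = {"input_ids": [[1, 2, 3]], "labels": [[1, 2, 3]]}
tuple_index_mapping = {0: "input_ids", signature_names.index("labels"): "labels"}

# Fill every real argument with its declared default, then overwrite mapped slots.
list_input = [signature[name].default for name in signature_names if name != "kwargs"]
for index, value in sorted(tuple_index_mapping.items()):
    list_input[index] = prepared_for_class[value]

tuple_input = tuple(list_input)
print(tuple_input[:-1])  # ([[1, 2, 3]], None, [[1, 2, 3]]) -- training dropped
```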

def _generate_random_bad_tokens(self, num_bad_tokens, model):