Commit ac8473f

jvcop (Jean Vancoppenole), tgatrtr, tobigue, and patrickvonplaten authored and committed
Add support for pretraining recurring span selection to Splinter (huggingface#17247)
* Add SplinterForSpanSelection for pre-training recurring span selection.
* Formatting.
* Rename SplinterForSpanSelection to SplinterForPreTraining.
* Ensure repo consistency
* Fixup changes
* Address SplinterForPreTraining PR comments
* Incorporate feedback and derive multiple question tokens per example.
* Update src/transformers/models/splinter/modeling_splinter.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Update src/transformers/models/splinter/modeling_splinter.py
  Co-authored-by: Patrick von Platen <[email protected]>

Co-authored-by: Jean Vancoppenole <[email protected]>
Co-authored-by: Tobias Günther <[email protected]>
Co-authored-by: Tobias Günther <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 92f016e commit ac8473f

File tree

6 files changed: +435 −18 lines changed

docs/source/en/model_doc/splinter.mdx

Lines changed: 5 additions & 0 deletions

@@ -72,3 +72,8 @@ This model was contributed by [yuvalkirstain](https://huggingface.co/yuvalkirsta
 
 [[autodoc]] SplinterForQuestionAnswering
     - forward
+
+## SplinterForPreTraining
+
+[[autodoc]] SplinterForPreTraining
+    - forward

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1532,6 +1532,7 @@
     _import_structure["models.splinter"].extend(
         [
             "SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "SplinterForPreTraining",
             "SplinterForQuestionAnswering",
             "SplinterLayer",
             "SplinterModel",
@@ -3830,6 +3831,7 @@
     from .models.speech_to_text_2 import Speech2Text2ForCausalLM, Speech2Text2PreTrainedModel
     from .models.splinter import (
         SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST,
+        SplinterForPreTraining,
         SplinterForQuestionAnswering,
         SplinterLayer,
         SplinterModel,

src/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions

@@ -161,6 +161,7 @@
         ("openai-gpt", "OpenAIGPTLMHeadModel"),
         ("retribert", "RetriBertModel"),
         ("roberta", "RobertaForMaskedLM"),
+        ("splinter", "SplinterForPreTraining"),
         ("squeezebert", "SqueezeBertForMaskedLM"),
         ("t5", "T5ForConditionalGeneration"),
         ("tapas", "TapasForMaskedLM"),

src/transformers/models/splinter/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -42,6 +42,7 @@
     _import_structure["modeling_splinter"] = [
         "SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST",
         "SplinterForQuestionAnswering",
+        "SplinterForPreTraining",
         "SplinterLayer",
         "SplinterModel",
         "SplinterPreTrainedModel",
@@ -68,6 +69,7 @@
 else:
     from .modeling_splinter import (
         SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST,
+        SplinterForPreTraining,
         SplinterForQuestionAnswering,
         SplinterLayer,
         SplinterModel,

src/transformers/models/splinter/modeling_splinter.py

Lines changed: 170 additions & 1 deletion
@@ -16,6 +16,7 @@
 
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -24,7 +25,7 @@
 from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
-from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, QuestionAnsweringModelOutput
+from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
@@ -940,3 +941,171 @@ def forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+@dataclass
+class SplinterForPreTrainingOutput(ModelOutput):
+    """
+    Class for outputs of Splinter as a span selection model.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when start and end positions are provided):
+            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+        start_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
+            Span-start scores (before SoftMax).
+        end_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
+            Span-end scores (before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    start_logits: torch.FloatTensor = None
+    end_logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@add_start_docstrings(
+    """
+    Splinter Model for the recurring span selection task as done during the pretraining. The difference to the QA task
+    is that we do not have a question, but multiple question tokens that replace the occurrences of recurring spans
+    instead.
+    """,
+    SPLINTER_START_DOCSTRING,
+)
+class SplinterForPreTraining(SplinterPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.splinter = SplinterModel(config)
+        self.splinter_qass = QuestionAwareSpanSelectionHead(config)
+        self.question_token_id = config.question_token_id
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(
+        SPLINTER_INPUTS_DOCSTRING.format("batch_size, num_questions, sequence_length")
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        question_positions: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, SplinterForPreTrainingOutput]:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
+            The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size,
+            num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be
+            the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size,
+            sequence_length)`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if question_positions is None and start_positions is not None and end_positions is not None:
+            raise TypeError("question_positions must be specified in order to calculate the loss")
+
+        elif question_positions is None and input_ids is None:
+            raise TypeError("question_positions must be specified when input_embeds is used")
+
+        elif question_positions is None:
+            question_positions = self._prepare_question_positions(input_ids)
+
+        outputs = self.splinter(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        batch_size, sequence_length, dim = sequence_output.size()
+        # [batch_size, num_questions, sequence_length]
+        start_logits, end_logits = self.splinter_qass(sequence_output, question_positions)
+
+        num_questions = question_positions.size(1)
+        if attention_mask is not None:
+            attention_mask_for_each_question = attention_mask.unsqueeze(1).expand(
+                batch_size, num_questions, sequence_length
+            )
+            start_logits = start_logits + (1 - attention_mask_for_each_question) * -10000.0
+            end_logits = end_logits + (1 - attention_mask_for_each_question) * -10000.0
+
+        total_loss = None
+        # [batch_size, num_questions, sequence_length]
+        if start_positions is not None and end_positions is not None:
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            start_positions.clamp_(0, max(0, sequence_length - 1))
+            end_positions.clamp_(0, max(0, sequence_length - 1))
+
+            # Ignore zero positions in the loss. Splinter never predicts zero
+            # during pretraining and zero is used for padding question
+            # tokens as well as for start and end positions of padded
+            # question tokens.
+            loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+            start_loss = loss_fct(
+                start_logits.view(batch_size * num_questions, sequence_length),
+                start_positions.view(batch_size * num_questions),
+            )
+            end_loss = loss_fct(
+                end_logits.view(batch_size * num_questions, sequence_length),
+                end_positions.view(batch_size * num_questions),
+            )
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[1:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return SplinterForPreTrainingOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def _prepare_question_positions(self, input_ids: torch.Tensor) -> torch.Tensor:
+        rows, flat_positions = torch.where(input_ids == self.config.question_token_id)
+        num_questions = torch.bincount(rows)
+        positions = torch.full(
+            (input_ids.size(0), num_questions.max()),
+            self.config.pad_token_id,
+            dtype=torch.long,
+            device=input_ids.device,
+        )
+        cols = torch.cat([torch.arange(n) for n in num_questions])
+        positions[rows, cols] = flat_positions
+        return positions
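
To see what the new _prepare_question_positions helper computes, here is a standalone sketch of the same logic on made-up inputs (the token ids below are hypothetical; in the model they come from config.question_token_id and config.pad_token_id):

import torch

question_token_id = 104  # hypothetical id, read from the config in the model
pad_token_id = 0         # hypothetical id, likewise from the config

input_ids = torch.tensor(
    [
        [101, 104, 7, 8, 104, 102],  # question tokens at positions 1 and 4
        [101, 9, 104, 10, 102, 0],   # one question token, at position 2
    ]
)

rows, flat_positions = torch.where(input_ids == question_token_id)
# rows = [0, 0, 1]; flat_positions = [1, 4, 2]
num_questions = torch.bincount(rows)  # [2, 1] question tokens per example
positions = torch.full(
    (input_ids.size(0), num_questions.max()), pad_token_id, dtype=torch.long
)
cols = torch.cat([torch.arange(n) for n in num_questions])  # [0, 1, 0]
positions[rows, cols] = flat_positions
print(positions)  # tensor([[1, 4], [2, 0]]) -- ragged counts padded with pad_token_id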

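Taken together, the new head can be exercised end to end. A minimal sketch, assuming the `tau/splinter-base` checkpoint and its tokenizer's [QUESTION] special token (both per the Splinter docs); the example text and all label indices are made up for illustration:

import torch
from transformers import AutoTokenizer, SplinterForPreTraining

tokenizer = AutoTokenizer.from_pretrained("tau/splinter-base")
model = SplinterForPreTraining.from_pretrained("tau/splinter-base")

# In recurring span selection, a span that occurs several times in a passage is
# kept once and its other occurrences are replaced by [QUESTION] tokens.
text = "Rome is famous. [QUESTION] was not built in a day."
inputs = tokenizer(text, return_tensors="pt")

# Without labels, question positions are derived from the [QUESTION] tokens in
# input_ids (via the _prepare_question_positions helper shown above).
outputs = model(**inputs)
print(outputs.start_logits.shape)  # (batch_size, num_questions, sequence_length)

# To get a loss, question_positions must be passed explicitly together with the
# labels, all shaped (batch_size, num_questions). Indices here are hypothetical.
question_positions = torch.tensor([[5]])
start_positions = torch.tensor([[1]])
end_positions = torch.tensor([[1]])
outputs = model(
    **inputs,
    question_positions=question_positions,
    start_positions=start_positions,
    end_positions=end_positions,
)
print(outputs.loss)
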
0 commit comments