16 | 16 | |
17 | 17 | |
18 | 18 | import math |
| 19 | +from dataclasses import dataclass |
19 | 20 | from typing import List, Optional, Tuple, Union |
20 | 21 | |
21 | 22 | import torch |
24 | 25 | from torch.nn import CrossEntropyLoss |
25 | 26 | |
26 | 27 | from ...activations import ACT2FN |
27 | | -from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, QuestionAnsweringModelOutput |
| 28 | +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput |
28 | 29 | from ...modeling_utils import PreTrainedModel |
29 | 30 | from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer |
30 | 31 | from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging |
@@ -940,3 +941,171 @@ def forward( |
940 | 941 | hidden_states=outputs.hidden_states, |
941 | 942 | attentions=outputs.attentions, |
942 | 943 | ) |
| 944 | + |
| 945 | + |
| 946 | +@dataclass |
| 947 | +class SplinterForPreTrainingOutput(ModelOutput): |
| 948 | + """ |
| 949 | + Class for outputs of Splinter as a span selection model. |
| 950 | + |
| 951 | + Args: |
| 952 | + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when start and end positions are provided): |
| 953 | + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. |
| 954 | + start_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`): |
| 955 | + Span-start scores (before SoftMax). |
| 956 | + end_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`): |
| 957 | + Span-end scores (before SoftMax). |
| 958 | + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): |
| 959 | + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + |
| 960 | + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. |
| 961 | + |
| 962 | + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. |
| 963 | + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): |
| 964 | + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, |
| 965 | + sequence_length)`. |
| 966 | + |
| 967 | + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
| 968 | + heads. |
| 969 | + """ |
| 970 | + |
| 971 | + loss: Optional[torch.FloatTensor] = None |
| 972 | + start_logits: torch.FloatTensor = None |
| 973 | + end_logits: torch.FloatTensor = None |
| 974 | + hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
| 975 | + attentions: Optional[Tuple[torch.FloatTensor]] = None |
| 976 | + |
| 977 | + |
| 978 | +@add_start_docstrings( |
| 979 | + """ |
| 980 | + Splinter Model for the recurring span selection task as done during pretraining. The difference from the QA task |
| 981 | + is that we do not have a question, but rather multiple question tokens that replace the occurrences of recurring |
| 982 | + spans. |
| 983 | + """, |
| 984 | + SPLINTER_START_DOCSTRING, |
| 985 | +) |
| 986 | +class SplinterForPreTraining(SplinterPreTrainedModel): |
| 987 | + def __init__(self, config): |
| 988 | + super().__init__(config) |
| 989 | + |
| 990 | + self.splinter = SplinterModel(config) |
| 991 | + self.splinter_qass = QuestionAwareSpanSelectionHead(config) |
| 992 | + self.question_token_id = config.question_token_id |
| 993 | + |
| 994 | + # Initialize weights and apply final processing |
| 995 | + self.post_init() |
| 996 | + |
| 997 | + @add_start_docstrings_to_model_forward( |
| 998 | + SPLINTER_INPUTS_DOCSTRING.format("batch_size, num_questions, sequence_length") |
| 999 | + ) |
| 1000 | + def forward( |
| 1001 | + self, |
| 1002 | + input_ids: Optional[torch.Tensor] = None, |
| 1003 | + attention_mask: Optional[torch.Tensor] = None, |
| 1004 | + token_type_ids: Optional[torch.Tensor] = None, |
| 1005 | + position_ids: Optional[torch.Tensor] = None, |
| 1006 | + head_mask: Optional[torch.Tensor] = None, |
| 1007 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 1008 | + start_positions: Optional[torch.LongTensor] = None, |
| 1009 | + end_positions: Optional[torch.LongTensor] = None, |
| 1010 | + output_attentions: Optional[bool] = None, |
| 1011 | + output_hidden_states: Optional[bool] = None, |
| 1012 | + return_dict: Optional[bool] = None, |
| 1013 | + question_positions: Optional[torch.LongTensor] = None, |
| 1014 | + ) -> Union[Tuple, SplinterForPreTrainingOutput]: |
| 1015 | + r""" |
| 1016 | + start_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*): |
| 1017 | + Labels for position (index) of the start of the labelled span for computing the token classification loss. |
| 1018 | + Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence |
| 1019 | + are not taken into account for computing the loss. |
| 1020 | + end_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*): |
| 1021 | + Labels for position (index) of the end of the labelled span for computing the token classification loss. |
| 1022 | + Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence |
| 1023 | + are not taken into account for computing the loss. |
| 1024 | + question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*): |
| 1025 | + The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size, |
| 1026 | + num_questions, sequence_length)`. If None, the positions are inferred from `input_ids` by locating every |
| 1027 | + occurrence of `config.question_token_id` (and padding with `config.pad_token_id`), so `input_ids` must be |
| 1028 | + provided; computing the loss requires passing `question_positions` explicitly. |
| 1029 | + """ |
| 1030 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| 1031 | + |
| 1032 | + if question_positions is None and start_positions is not None and end_positions is not None: |
| 1033 | + raise TypeError("question_positions must be specified in order to calculate the loss") |
| 1034 | + |
| 1035 | + elif question_positions is None and input_ids is None: |
| 1036 | + raise TypeError("question_positions must be specified when inputs_embeds is used") |
| 1037 | + |
| 1038 | + elif question_positions is None: |
| 1039 | + question_positions = self._prepare_question_positions(input_ids) |
| 1040 | + |
| 1041 | + outputs = self.splinter( |
| 1042 | + input_ids, |
| 1043 | + attention_mask=attention_mask, |
| 1044 | + token_type_ids=token_type_ids, |
| 1045 | + position_ids=position_ids, |
| 1046 | + head_mask=head_mask, |
| 1047 | + inputs_embeds=inputs_embeds, |
| 1048 | + output_attentions=output_attentions, |
| 1049 | + output_hidden_states=output_hidden_states, |
| 1050 | + return_dict=return_dict, |
| 1051 | + ) |
| 1052 | + |
| 1053 | + sequence_output = outputs[0] |
| 1054 | + batch_size, sequence_length, dim = sequence_output.size() |
| 1055 | + # [batch_size, num_questions, sequence_length] |
| 1056 | + start_logits, end_logits = self.splinter_qass(sequence_output, question_positions) |
| 1057 | + |
| 1058 | + num_questions = question_positions.size(1) |
| 1059 | + if attention_mask is not None: |
| 1060 | + attention_mask_for_each_question = attention_mask.unsqueeze(1).expand( |
| 1061 | + batch_size, num_questions, sequence_length |
| 1062 | + ) |
| 1063 | + start_logits = start_logits + (1 - attention_mask_for_each_question) * -10000.0 |
| 1064 | + end_logits = end_logits + (1 - attention_mask_for_each_question) * -10000.0 |
| 1065 | + |
| 1066 | + total_loss = None |
| 1067 | + # [batch_size, num_questions, sequence_length] |
| 1068 | + if start_positions is not None and end_positions is not None: |
| 1069 | + # sometimes the start/end positions are outside our model inputs, we ignore these terms |
| 1070 | + start_positions.clamp_(0, max(0, sequence_length - 1)) |
| 1071 | + end_positions.clamp_(0, max(0, sequence_length - 1)) |
| 1072 | + |
| 1073 | + # Ignore zero positions in the loss. Splinter never predicts zero |
| 1074 | + # during pretraining and zero is used for padding question |
| 1075 | + # tokens as well as for start and end positions of padded |
| 1076 | + # question tokens. |
| 1077 | + loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id) |
| 1078 | + start_loss = loss_fct( |
| 1079 | + start_logits.view(batch_size * num_questions, sequence_length), |
| 1080 | + start_positions.view(batch_size * num_questions), |
| 1081 | + ) |
| 1082 | + end_loss = loss_fct( |
| 1083 | + end_logits.view(batch_size * num_questions, sequence_length), |
| 1084 | + end_positions.view(batch_size * num_questions), |
| 1085 | + ) |
| 1086 | + total_loss = (start_loss + end_loss) / 2 |
| 1087 | + |
| 1088 | + if not return_dict: |
| 1089 | + output = (start_logits, end_logits) + outputs[1:] |
| 1090 | + return ((total_loss,) + output) if total_loss is not None else output |
| 1091 | + |
| 1092 | + return SplinterForPreTrainingOutput( |
| 1093 | + loss=total_loss, |
| 1094 | + start_logits=start_logits, |
| 1095 | + end_logits=end_logits, |
| 1096 | + hidden_states=outputs.hidden_states, |
| 1097 | + attentions=outputs.attentions, |
| 1098 | + ) |
| 1099 | + |
| 1100 | + def _prepare_question_positions(self, input_ids: torch.Tensor) -> torch.Tensor: |
| 1101 | + rows, flat_positions = torch.where(input_ids == self.config.question_token_id) |
| 1102 | + num_questions = torch.bincount(rows) |
| 1103 | + positions = torch.full( |
| 1104 | + (input_ids.size(0), num_questions.max()), |
| 1105 | + self.config.pad_token_id, |
| 1106 | + dtype=torch.long, |
| 1107 | + device=input_ids.device, |
| 1108 | + ) |
| 1109 | + cols = torch.cat([torch.arange(n) for n in num_questions]) |
| 1110 | + positions[rows, cols] = flat_positions |
| 1111 | + return positions |
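
For reference, the position-gathering logic in `_prepare_question_positions` can be traced on a toy batch. The sketch below mirrors that logic as a standalone function; the token ids (`question_token_id=104`, `pad_token_id=0`) are placeholder values chosen for illustration, not Splinter's real vocabulary ids.

```python
import torch


def prepare_question_positions(input_ids, question_token_id, pad_token_id):
    # Row index and flat position of every question token in the batch
    rows, flat_positions = torch.where(input_ids == question_token_id)
    # Number of question tokens found in each sequence
    num_questions = torch.bincount(rows)
    # Pad every sequence to the largest question count in the batch
    positions = torch.full(
        (input_ids.size(0), num_questions.max()),
        pad_token_id,
        dtype=torch.long,
        device=input_ids.device,
    )
    cols = torch.cat([torch.arange(n) for n in num_questions])
    positions[rows, cols] = flat_positions
    return positions


# Toy batch: 104 stands in for the question token id, 0 for padding
input_ids = torch.tensor(
    [
        [101, 104, 7, 104, 102],  # question tokens at positions 1 and 3
        [101, 9, 104, 102, 0],    # a single question token at position 2
    ]
)
print(prepare_question_positions(input_ids, question_token_id=104, pad_token_id=0))
# tensor([[1, 3],
#         [2, 0]])
```

Padded question slots default to position 0, matching the convention noted in the loss computation above: `CrossEntropyLoss(ignore_index=self.config.pad_token_id)` simply skips those padded entries.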
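To show how the new head ties together, here is a minimal usage sketch. It assumes the `tau/splinter-base-qass` checkpoint and that the tokenizer keeps the literal `[QUESTION]` marker as a single special token mapped to `config.question_token_id`; treat it as an illustration of the intended call pattern rather than a tested snippet.

```python
import torch
from transformers import SplinterForPreTraining, SplinterTokenizer

# Assumed checkpoint; any Splinter checkpoint with a pretrained QASS head should behave similarly.
checkpoint = "tau/splinter-base-qass"
tokenizer = SplinterTokenizer.from_pretrained(checkpoint)
model = SplinterForPreTraining.from_pretrained(checkpoint)

# A recurring span is replaced by [QUESTION] tokens, as in the pretraining objective.
text = "[QUESTION] was born in Ulm. In 1905, [QUESTION] published four groundbreaking papers."
inputs = tokenizer(text, return_tensors="pt")

# Positions of the [QUESTION] tokens, shape (batch_size, num_questions)
question_positions = (
    (inputs["input_ids"] == model.config.question_token_id).nonzero()[:, 1].unsqueeze(0)
)

with torch.no_grad():
    outputs = model(**inputs, question_positions=question_positions)

# start_logits / end_logits have shape (batch_size, num_questions, sequence_length)
start = outputs.start_logits.argmax(dim=-1)
end = outputs.end_logits.argmax(dim=-1)
for q, (s, e) in enumerate(zip(start[0], end[0])):
    print(f"question {q}: predicted span tokens {s.item()}..{e.item()}")
```

Passing `question_positions` explicitly keeps the logits three-dimensional, one score matrix per question token, which is the shape the pretraining loss above expects.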