From f5ccdca1125d9b83c5e370475f22b3a843bb5289 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 29 Jun 2020 17:30:36 +0000 Subject: [PATCH 01/19] fix longformer --- src/transformers/modeling_longformer.py | 220 ++++++++++++++++++++---- 1 file changed, 190 insertions(+), 30 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 9d869e73a1c5..4daed3c27bd7 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -25,8 +25,9 @@ from .configuration_longformer import LongformerConfig from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_bert import BertPreTrainedModel -from .modeling_roberta import RobertaLMHead, RobertaModel +from .modeling_bert import BertIntermediate, BertLayerNorm, BertOutput, BertPooler, BertPreTrainedModel, BertSelfOutput +from .modeling_roberta import RobertaEmbeddings, RobertaLMHead +from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer logger = logging.getLogger(__name__) @@ -237,13 +238,7 @@ def _sliding_chunks_matmul_pv(self, prob: torch.Tensor, v: torch.Tensor, w: int) return context.view(batch_size, num_heads, seqlen, head_dim).transpose(1, 2) def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - output_attentions=False, + self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, ): """ LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`. 
@@ -254,11 +249,7 @@ def forward( 0: local attention +ve: global attention - `encoder_hidden_states` and `encoder_attention_mask` are not supported and should be None """ - # TODO: add support for `encoder_hidden_states` and `encoder_attention_mask` - assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None" - assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and shiould be None" if attention_mask is not None: attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) @@ -435,6 +426,129 @@ def forward( return outputs +class LongformerAttention(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.self = LongformerSelfAttention(config, layer_id) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, + ): + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions,) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class LongformerLayer(nn.Module): + def 
__init__(self, config, layer_id=0): + super().__init__() + self.attention = LongformerAttention(config, layer_id) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, + ): + self_attention_outputs = self.attention( + hidden_states, attention_mask, head_mask, output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + outputs + return outputs + + +class LongformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)]) + + def forward( + self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, + ): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), hidden_states, attention_mask, head_mask[i], + ) + else: + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], output_attentions,) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = (hidden_states,) 
+ if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + +class LongformerPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained + models. + """ + + config_class = LongformerConfig + base_model_prefix = "longformer" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + LONGFORMER_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `__ sub-class. 
@@ -498,7 +612,7 @@ def forward( "The bare Longformer Model outputting raw hidden-states without any specific head on top.", LONGFORMER_START_DOCSTRING, ) -class LongformerModel(RobertaModel): +class LongformerModel(LongformerPreTrainedModel): """ This class overrides :class:`~transformers.RobertaModel` to provide the ability to process long sequences following the selfattention approach described in `Longformer: the Long-Document Transformer @@ -530,12 +644,26 @@ def __init__(self, config): f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" ) - for i, layer in enumerate(self.encoder.layer): - # replace the `modeling_bert.BertSelfAttention` object with `LongformerSelfAttention` - layer.attention.self = LongformerSelfAttention(config, layer_id=i) + self.embeddings = RobertaEmbeddings(config) + self.encoder = LongformerEncoder(config) + self.pooler = BertPooler(config) self.init_weights() + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. 
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + def _pad_to_window_size( self, input_ids: torch.Tensor, @@ -587,6 +715,7 @@ def forward( global_attention_mask=None, token_type_ids=None, position_ids=None, + head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, @@ -634,6 +763,22 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + # padding attention_window = ( self.config.attention_window @@ -663,19 +808,34 @@ def forward( pad_token_id=self.config.pad_token_id, ) - # embed - output = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=None, - inputs_embeds=inputs_embeds, - encoder_hidden_states=None, - encoder_attention_mask=None, + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + outputs = (sequence_output, pooled_output,) + encoder_outputs[ + 1: + ] # add hidden_states and attentions if they are here # undo padding if padding_len > 0: @@ -684,13 +844,13 @@ def forward( # `pooled_output`: independent of the sequence length # `hidden_states`: mainly used for debugging and analysis, so keep the padding # `attentions`: mainly used for debugging and analysis, so keep the padding - output = output[0][:, :-padding_len], *output[1:] + outputs = outputs[0][:, :-padding_len], *outputs[1:] - return output + return outputs @add_start_docstrings("""Longformer Model with a `language modeling` head on top. 
""", LONGFORMER_START_DOCSTRING) -class LongformerForMaskedLM(BertPreTrainedModel): +class LongformerForMaskedLM(LongformerPreTrainedModel): config_class = LongformerConfig base_model_prefix = "longformer" From 18519d61551ede1317a646d49380efe50b6c14ea Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 29 Jun 2020 17:31:21 +0000 Subject: [PATCH 02/19] fix longformer --- src/transformers/modeling_longformer.py | 54 +++++++++++++++---------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 4daed3c27bd7..916bd0d583b2 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -117,11 +117,11 @@ def __init__(self, config, layer_id): self.one_sided_attention_window_size = attention_window // 2 @staticmethod - def _skew(x, direction): + def _skew(hidden_states, direction): """Convert diagonals into columns (or columns into diagonals depending on `direction`""" - x_padded = F.pad(x, direction) # padding value is not important because it will be overwritten - x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2)) - return x_padded + hidden_states_padded = F.pad(hidden_states, direction) # padding value is not important because it will be overwritten + hidden_states_padded = hidden_states_padded.view(*hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2)) + return hidden_states_padded @staticmethod def _skew2(x): @@ -238,7 +238,10 @@ def _sliding_chunks_matmul_pv(self, prob: torch.Tensor, v: torch.Tensor, w: int) return context.view(batch_size, num_heads, seqlen, head_dim).transpose(1, 2) def forward( - self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, + self, + hidden_states, + attention_mask=None, + output_attentions=False, ): """ LongformerSelfAttention expects `len(hidden_states)` to be multiple of 
`attention_window`. @@ -452,9 +455,14 @@ def prune_heads(self, heads): self.pruned_heads = self.pruned_heads.union(heads) def forward( - self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, + self, + hidden_states, + attention_mask=None, + output_attentions=False, ): - self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions,) + self_outputs = self.self( + hidden_states, attention_mask, output_attentions, + ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs @@ -468,10 +476,13 @@ def __init__(self, config, layer_id=0): self.output = BertOutput(config) def forward( - self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, + self, + hidden_states, + attention_mask=None, + output_attentions=False, ): self_attention_outputs = self.attention( - hidden_states, attention_mask, head_mask, output_attentions=output_attentions, + hidden_states, attention_mask, output_attentions=output_attentions, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights @@ -489,7 +500,11 @@ def __init__(self, config): self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)]) def forward( - self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, ): all_hidden_states = () all_attentions = () @@ -506,10 +521,16 @@ def custom_forward(*inputs): return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, attention_mask, head_mask[i], + create_custom_forward(layer_module), + hidden_states, + attention_mask, ) else: - layer_outputs = 
layer_module(hidden_states, attention_mask, head_mask[i], output_attentions,) + layer_outputs = layer_module( + hidden_states, + attention_mask, + output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: @@ -715,7 +736,6 @@ def forward( global_attention_mask=None, token_type_ids=None, position_ids=None, - head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, @@ -812,13 +832,6 @@ def forward( # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds ) @@ -826,7 +839,6 @@ def forward( encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, - head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) From 16d956ea04b7aaf888f489c109505ae30d3f0fe3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 29 Jun 2020 17:31:51 +0000 Subject: [PATCH 03/19] fix longformer --- src/transformers/modeling_longformer.py | 137 ++++++++++++------------ 1 file changed, 68 insertions(+), 69 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 916bd0d583b2..a10e79677bf8 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -119,8 +119,12 @@ def __init__(self, config, layer_id): @staticmethod 
def _skew(hidden_states, direction): """Convert diagonals into columns (or columns into diagonals depending on `direction`""" - hidden_states_padded = F.pad(hidden_states, direction) # padding value is not important because it will be overwritten - hidden_states_padded = hidden_states_padded.view(*hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2)) + hidden_states_padded = F.pad( + hidden_states, direction + ) # padding value is not important because it will be overwritten + hidden_states_padded = hidden_states_padded.view( + *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2) + ) return hidden_states_padded @staticmethod @@ -136,82 +140,100 @@ def _skew2(x): return x @staticmethod - def _chunk(x, w): + def _chunk(hidden_states, window_overlap): """convert into overlapping chunkings. Chunk size = 2w, overlap size = w""" # non-overlapping chunks of size = 2w - x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2)) + hidden_states = hidden_states.view( + hidden_states.size(0), + hidden_states.size(1) // (window_overlap * 2), + window_overlap * 2, + hidden_states.size(2), + ) # use `as_strided` to make the chunks overlap with an overlap size = w - chunk_size = list(x.size()) + chunk_size = list(hidden_states.size()) chunk_size[1] = chunk_size[1] * 2 - 1 - chunk_stride = list(x.stride()) + chunk_stride = list(hidden_states.stride()) chunk_stride[1] = chunk_stride[1] // 2 - return x.as_strided(size=chunk_size, stride=chunk_stride) + return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) - def _mask_invalid_locations(self, input_tensor, w) -> torch.Tensor: - affected_seqlen = w - beginning_mask_2d = input_tensor.new_ones(w, w + 1).tril().flip(dims=[0]) + def _mask_invalid_locations(self, input_tensor, affected_seq_len) -> torch.Tensor: + beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0]) beginning_mask = 
beginning_mask_2d[None, :, None, :] ending_mask = beginning_mask.flip(dims=(1, 3)) - beginning_input = input_tensor[:, :affected_seqlen, :, : w + 1] + beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] beginning_mask = beginning_mask.expand(beginning_input.size()) beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 - ending_input = input_tensor[:, -affected_seqlen:, :, -(w + 1) :] + ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] ending_mask = ending_mask.expand(ending_input.size()) ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 - def _sliding_chunks_matmul_qk(self, q: torch.Tensor, k: torch.Tensor, w: int): + def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int): """Matrix multiplicatio of query x key tensors using with a sliding window attention pattern. This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) - with an overlap of size w""" - batch_size, seqlen, num_heads, head_dim = q.size() - assert seqlen % (w * 2) == 0, f"Sequence length should be multiple of {w * 2}. Given {seqlen}" - assert q.size() == k.size() + with an overlap of size window_overlap""" + batch_size, seqlen, num_heads, head_dim = query.size() + assert ( + seqlen % (window_overlap * 2) == 0 + ), f"Sequence length should be multiple of {window_overlap * 2}. 
Given {seqlen}" + assert query.size() == key.size() - chunks_count = seqlen // w - 1 + chunks_count = seqlen // window_overlap - 1 - # group batch_size and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2 - q = q.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim) - k = k.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim) + # group batch_size and num_heads dimensions into one, then chunk seqlen into chunks of size window_overlap * 2 + query = query.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim) + key = key.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim) - chunk_q = self._chunk(q, w) - chunk_k = self._chunk(k, w) + chunked_query = self._chunk(query, window_overlap) + chunked_key = self._chunk(key, window_overlap) # matrix multipication # bcxd: batch_size * num_heads x chunks x 2w x head_dim # bcyd: batch_size * num_heads x chunks x 2w x head_dim # bcxy: batch_size * num_heads x chunks x 2w x 2w - chunk_attn = torch.einsum("bcxd,bcyd->bcxy", (chunk_q, chunk_k)) # multiply + chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunked_query, chunked_key)) # multiply # convert diagonals into columns - diagonal_chunk_attn = self._skew(chunk_attn, direction=(0, 0, 0, 1)) + diagonal_chunked_attention_scores = self._skew(chunked_attention_scores, direction=(0, 0, 0, 1)) # allocate space for the overall attention matrix where the chunks are compined. The last dimension # has (w * 2 + 1) columns. The first (w) columns are the w lower triangles (attention from a word to # w previous words). The following column is attention score from each word to itself, then # followed by w columns for the upper triangle. 
- diagonal_attn = diagonal_chunk_attn.new_empty((batch_size * num_heads, chunks_count + 1, w, w * 2 + 1)) + diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty( + (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1) + ) - # copy parts from diagonal_chunk_attn into the compined matrix of attentions + # copy parts from diagonal_chunked_attention_scores into the compined matrix of attentions # - copying the main diagonal and the upper triangle - diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, : w + 1] - diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, : w + 1] + diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[ + :, :, :window_overlap, : window_overlap + 1 + ] + diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[ + :, -1, window_overlap:, : window_overlap + 1 + ] # - copying the lower triangle - diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, -(w + 1) : -1, w + 1 :] - diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, : w - 1, 1 - w :] + diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[ + :, :, -(window_overlap + 1) : -1, window_overlap + 1 : + ] + diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[ + :, 0, : window_overlap - 1, 1 - window_overlap : + ] # separate batch_size and num_heads dimensions again - diagonal_attn = diagonal_attn.view(batch_size, num_heads, seqlen, 2 * w + 1).transpose(2, 1) + diagonal_attention_scores = diagonal_attention_scores.view( + batch_size, num_heads, seqlen, 2 * window_overlap + 1 + ).transpose(2, 1) - self._mask_invalid_locations(diagonal_attn, w) - return diagonal_attn + self._mask_invalid_locations(diagonal_attention_scores, window_overlap) + return diagonal_attention_scores def _sliding_chunks_matmul_pv(self, prob: torch.Tensor, v: torch.Tensor, w: int): - """Same as 
_sliding_chunks_matmul_qk but for prob and value tensors. It is expecting the same output - format from _sliding_chunks_matmul_qk""" + """Same as _sliding_chunks_query_key_matmul but for prob and value tensors. It is expecting the same output + format from _sliding_chunks_query_key_matmul""" batch_size, seqlen, num_heads, head_dim = v.size() assert seqlen % (w * 2) == 0 assert prob.size()[:3] == v.size()[:3] @@ -238,10 +260,7 @@ def _sliding_chunks_matmul_pv(self, prob: torch.Tensor, v: torch.Tensor, w: int) return context.view(batch_size, num_heads, seqlen, head_dim).transpose(1, 2) def forward( - self, - hidden_states, - attention_mask=None, - output_attentions=False, + self, hidden_states, attention_mask=None, output_attentions=False, ): """ LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`. @@ -295,7 +314,7 @@ def forward( q = q.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1) k = k.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1) # attn_weights = (batch_size, seqlen, num_heads, window*2+1) - attn_weights = self._sliding_chunks_matmul_qk(q, k, self.one_sided_attention_window_size) + attn_weights = self._sliding_chunks_query_key_matmul(q, k, self.one_sided_attention_window_size) if remove_from_windowed_attention_mask is not None: # This implementation is fast and takes very little memory because num_heads x hidden_size = 1 # from (batch_size x seqlen) to (batch_size x seqlen x num_heads x hidden_size) @@ -308,7 +327,7 @@ def forward( ) ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones # diagonal mask with zeros everywhere and -inf inplace of padding - d_mask = self._sliding_chunks_matmul_qk(ones, float_mask, self.one_sided_attention_window_size) + d_mask = self._sliding_chunks_query_key_matmul(ones, float_mask, self.one_sided_attention_window_size) attn_weights += d_mask assert list(attn_weights.size()) == [ batch_size, @@ -455,14 +474,9 @@ def 
prune_heads(self, heads): self.pruned_heads = self.pruned_heads.union(heads) def forward( - self, - hidden_states, - attention_mask=None, - output_attentions=False, + self, hidden_states, attention_mask=None, output_attentions=False, ): - self_outputs = self.self( - hidden_states, attention_mask, output_attentions, - ) + self_outputs = self.self(hidden_states, attention_mask, output_attentions,) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs @@ -476,14 +490,9 @@ def __init__(self, config, layer_id=0): self.output = BertOutput(config) def forward( - self, - hidden_states, - attention_mask=None, - output_attentions=False, + self, hidden_states, attention_mask=None, output_attentions=False, ): - self_attention_outputs = self.attention( - hidden_states, attention_mask, output_attentions=output_attentions, - ) + self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights @@ -500,11 +509,7 @@ def __init__(self, config): self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)]) def forward( - self, - hidden_states, - attention_mask=None, - output_attentions=False, - output_hidden_states=False, + self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False, ): all_hidden_states = () all_attentions = () @@ -521,16 +526,10 @@ def custom_forward(*inputs): return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), - hidden_states, - attention_mask, + create_custom_forward(layer_module), hidden_states, attention_mask, ) else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - output_attentions, - ) + layer_outputs = 
layer_module(hidden_states, attention_mask, output_attentions,) hidden_states = layer_outputs[0] if output_attentions: From 01d3d841d45702ed7998ed385a2f4a9c2e6e3405 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 25 Jun 2020 13:53:24 +0200 Subject: [PATCH 04/19] refactor naming --- src/transformers/modeling_longformer.py | 125 ++++++++++++------------ 1 file changed, 62 insertions(+), 63 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index a10e79677bf8..bfa3800cbf82 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -117,10 +117,10 @@ def __init__(self, config, layer_id): self.one_sided_attention_window_size = attention_window // 2 @staticmethod - def _skew(hidden_states, direction): - """Convert diagonals into columns (or columns into diagonals depending on `direction`""" + def pad_and_transpose_last_two_dims(hidden_states_padded, padding): + """Convert diagonals into columns (or columns into diagonals depending on `padding`""" hidden_states_padded = F.pad( - hidden_states, direction + hidden_states_padded, padding ) # padding value is not important because it will be overwritten hidden_states_padded = hidden_states_padded.view( *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2) @@ -128,16 +128,15 @@ def _skew(hidden_states, direction): return hidden_states_padded @staticmethod - def _skew2(x): + def _value_attention_probs_skew(chunked_hidden_states): """shift every row 1 step to right converting columns into diagonals""" - # X = B x C x M x L - B, C, M, L = x.size() - x = F.pad(x, (0, M + 1)) # B x C x M x (L+M+1). 
Padding value is not important because it'll be overwritten - x = x.view(B, C, -1) # B x C x ML+MM+M - x = x[:, :, :-M] # B x C x ML+MM - x = x.view(B, C, M, M + L) # B x C, M x L+M - x = x[:, :, :, :-1] - return x + total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size() + chunked_hidden_states = F.pad(chunked_hidden_states, (0, window_overlap + 1)) # B x C x M x (L+M+1). Padding value is not important because it'll be overwritten + chunked_hidden_states = chunked_hidden_states.view(total_num_heads, num_chunks, -1) # B x C x ML+MM+M + chunked_hidden_states = chunked_hidden_states[:, :, :-window_overlap] # B x C x ML+MM + chunked_hidden_states = chunked_hidden_states.view(total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim) # B x C, M x L+M + chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] + return chunked_hidden_states @staticmethod def _chunk(hidden_states, window_overlap): @@ -174,17 +173,17 @@ def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tenso """Matrix multiplicatio of query x key tensors using with a sliding window attention pattern. This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an overlap of size window_overlap""" - batch_size, seqlen, num_heads, head_dim = query.size() + batch_size, seq_len, num_heads, head_dim = query.size() assert ( - seqlen % (window_overlap * 2) == 0 - ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seqlen}" + seq_len % (window_overlap * 2) == 0 + ), f"Sequence length should be multiple of {window_overlap * 2}. 
Given {seq_len}" assert query.size() == key.size() - chunks_count = seqlen // window_overlap - 1 + chunks_count = seq_len // window_overlap - 1 - # group batch_size and num_heads dimensions into one, then chunk seqlen into chunks of size window_overlap * 2 - query = query.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim) - key = key.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim) + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 + query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) + key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) chunked_query = self._chunk(query, window_overlap) chunked_key = self._chunk(key, window_overlap) @@ -196,7 +195,7 @@ def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tenso chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunked_query, chunked_key)) # multiply # convert diagonals into columns - diagonal_chunked_attention_scores = self._skew(chunked_attention_scores, direction=(0, 0, 0, 1)) + diagonal_chunked_attention_scores = self.pad_and_transpose_last_two_dims(chunked_attention_scores, padding=(0, 0, 0, 1)) # allocate space for the overall attention matrix where the chunks are compined. The last dimension # has (w * 2 + 1) columns. 
The first (w) columns are the w lower triangles (attention from a word to @@ -225,39 +224,39 @@ def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tenso # separate batch_size and num_heads dimensions again diagonal_attention_scores = diagonal_attention_scores.view( - batch_size, num_heads, seqlen, 2 * window_overlap + 1 + batch_size, num_heads, seq_len, 2 * window_overlap + 1 ).transpose(2, 1) self._mask_invalid_locations(diagonal_attention_scores, window_overlap) return diagonal_attention_scores - def _sliding_chunks_matmul_pv(self, prob: torch.Tensor, v: torch.Tensor, w: int): - """Same as _sliding_chunks_query_key_matmul but for prob and value tensors. It is expecting the same output + def _sliding_chunks_matmul_attention_probs_value(self, attention_probs: torch.Tensor, value: torch.Tensor, window_overlap: int): + """Same as _sliding_chunks_query_key_matmul but for attention_probs and value tensors. It is expecting the same output format from _sliding_chunks_query_key_matmul""" - batch_size, seqlen, num_heads, head_dim = v.size() - assert seqlen % (w * 2) == 0 - assert prob.size()[:3] == v.size()[:3] - assert prob.size(3) == 2 * w + 1 - chunks_count = seqlen // w - 1 - # group batch_size and num_heads dimensions into one, then chunk seqlen into chunks of size 2w - chunk_prob = prob.transpose(1, 2).reshape(batch_size * num_heads, seqlen // w, w, 2 * w + 1) + batch_size, seq_len, num_heads, head_dim = value.size() + assert seq_len % (window_overlap * 2) == 0 + assert attention_probs.size()[:3] == value.size()[:3] + assert attention_probs.size(3) == 2 * window_overlap + 1 + chunks_count = seq_len // window_overlap - 1 + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap + chunked_attention_probs = attention_probs.transpose(1, 2).reshape(batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1) # group batch_size and num_heads dimensions into one - v = 
v.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim) + value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) - # pad seqlen with w at the beginning of the sequence and another w at the end - padded_v = F.pad(v, (0, 0, w, w), value=-1) + # pad seq_len with w at the beginning of the sequence and another window overlap at the end + padded_value = F.pad(value, (0, 0, window_overlap, window_overlap), value=-1) - # chunk padded_v into chunks of size 3w and an overlap of size w - chunk_v_size = (batch_size * num_heads, chunks_count + 1, 3 * w, head_dim) - chunk_v_stride = padded_v.stride() - chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2] - chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride) + # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap + chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim) + chunked_value_stride = padded_value.stride() + chunked_value_stride = chunked_value_stride[0], window_overlap * chunked_value_stride[1], chunked_value_stride[1], chunked_value_stride[2] + chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride) - skewed_prob = self._skew2(chunk_prob) + skewed_prob = self._value_attention_probs_skew(chunked_attention_probs) - context = torch.einsum("bcwd,bcdh->bcwh", (skewed_prob, chunk_v)) - return context.view(batch_size, num_heads, seqlen, head_dim).transpose(1, 2) + context = torch.einsum("bcwd,bcdh->bcwh", (skewed_prob, chunked_value)) + return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2) def forward( self, hidden_states, attention_mask=None, output_attentions=False, @@ -304,20 +303,20 @@ def forward( key_padding_mask = None hidden_states = hidden_states.transpose(0, 1) - seqlen, batch_size, embed_dim = hidden_states.size() + seq_len, batch_size, embed_dim = hidden_states.size() assert 
embed_dim == self.embed_dim q = self.query(hidden_states) k = self.key(hidden_states) v = self.value(hidden_states) q /= math.sqrt(self.head_dim) - q = q.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - k = k.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - # attn_weights = (batch_size, seqlen, num_heads, window*2+1) + q = q.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + k = k.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + # attn_weights = (batch_size, seq_len, num_heads, window*2+1) attn_weights = self._sliding_chunks_query_key_matmul(q, k, self.one_sided_attention_window_size) if remove_from_windowed_attention_mask is not None: # This implementation is fast and takes very little memory because num_heads x hidden_size = 1 - # from (batch_size x seqlen) to (batch_size x seqlen x num_heads x hidden_size) + # from (batch_size x seq_len) to (batch_size x seq_len x num_heads x hidden_size) remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze( dim=-1 ) @@ -331,7 +330,7 @@ def forward( attn_weights += d_mask assert list(attn_weights.size()) == [ batch_size, - seqlen, + seq_len, self.num_heads, self.one_sided_attention_window_size * 2 + 1, ] @@ -340,11 +339,11 @@ def forward( if extra_attention_mask is not None: selected_k = k.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros] - # (batch_size, seqlen, num_heads, max_num_extra_indices_per_batch) + # (batch_size, seq_len, num_heads, max_num_extra_indices_per_batch) selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k)) selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0 # concat to attn_weights - # (batch_size, seqlen, num_heads, extra attention count + 2*window+1) + # (batch_size, seq_len, 
num_heads, extra attention count + 2*window+1) attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) attn_weights_fp32 = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability @@ -355,7 +354,7 @@ def forward( attn_weights = torch.masked_fill(attn_weights, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0) attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) - v = v.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + v = v.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) attn = None if extra_attention_mask is not None: selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch) @@ -368,12 +367,12 @@ def forward( -1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch ).contiguous() if attn is None: - attn = self._sliding_chunks_matmul_pv(attn_probs, v, self.one_sided_attention_window_size) + attn = self._sliding_chunks_matmul_attention_probs_value(attn_probs, v, self.one_sided_attention_window_size) else: - attn += self._sliding_chunks_matmul_pv(attn_probs, v, self.one_sided_attention_window_size) + attn += self._sliding_chunks_matmul_attention_probs_value(attn_probs, v, self.one_sided_attention_window_size) - assert attn.size() == (batch_size, seqlen, self.num_heads, self.head_dim), "Unexpected size" - attn = attn.transpose(0, 1).reshape(seqlen, batch_size, embed_dim).contiguous() + assert attn.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" + attn = attn.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() # For this case, we'll just recompute the attention for these indices # and overwrite the attn tensor. 
@@ -396,18 +395,18 @@ def forward( ) # (batch_size * self.num_heads, max_num_extra_indices_per_batch, head_dim) k = ( k.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) - ) # batch_size * self.num_heads, seqlen, head_dim) + ) # batch_size * self.num_heads, seq_len, head_dim) v = ( v.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) - ) # batch_size * self.num_heads, seqlen, head_dim) + ) # batch_size * self.num_heads, seq_len, head_dim) attn_weights = torch.bmm(q, k.transpose(1, 2)) - assert list(attn_weights.size()) == [batch_size * self.num_heads, max_num_extra_indices_per_batch, seqlen] + assert list(attn_weights.size()) == [batch_size * self.num_heads, max_num_extra_indices_per_batch, seq_len] - attn_weights = attn_weights.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seqlen) + attn_weights = attn_weights.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seq_len) attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0 if key_padding_mask is not None: attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -10000.0,) - attn_weights = attn_weights.view(batch_size * self.num_heads, max_num_extra_indices_per_batch, seqlen) + attn_weights = attn_weights.view(batch_size * self.num_heads, max_num_extra_indices_per_batch, seq_len) attn_weights_float = F.softmax( attn_weights, dim=-1, dtype=torch.float32 ) # use fp32 for numerical stability @@ -438,7 +437,7 @@ def forward( # It doesn't not return local attention # In case of variable number of global attantion in the rows of a batch, # attn_weights are padded with -10000.0 attention scores - attn_weights = attn_weights.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seqlen) + attn_weights = attn_weights.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seq_len) else: # without global attention, return local attention 
probabilities # batch_size x num_heads x sequence_length x window_size @@ -698,13 +697,13 @@ def _pad_to_window_size( assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}" input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape - batch_size, seqlen = input_shape[:2] + batch_size, seq_len = input_shape[:2] - padding_len = (attention_window - seqlen % attention_window) % attention_window + padding_len = (attention_window - seq_len % attention_window) % attention_window if padding_len > 0: logger.info( "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format( - seqlen, seqlen + padding_len, attention_window + seq_len, seq_len + padding_len, attention_window ) ) if input_ids is not None: From d01fdf7fd5032d1f3f26104c0c5b071fba788600 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 25 Jun 2020 14:17:35 +0200 Subject: [PATCH 05/19] add small slow test --- tests/test_modeling_longformer.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index 7579ee38ba0d..c72406bb762c 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -325,6 +325,18 @@ def test_inference_no_head(self): model = LongformerModel.from_pretrained("allenai/longformer-base-4096") model.to(torch_device) + # 'Hello world!' + input_ids = torch.tensor([[0, 20920, 232, 328, 1437, 2]], dtype=torch.long, device=torch_device) + output = model(input_ids)[0] + + expected_output_slice = torch.tensor([0.0105, 0.1807, -0.2062, -0.0643, 0.0629], device=torch_device) + self.assertTrue(torch.allclose(output[0, 0, -5:], expected_output_slice, atol=1e-4)) + + @slow + def test_inference_no_head_long(self): + model = LongformerModel.from_pretrained("allenai/longformer-base-4096") + model.to(torch_device) + # 'Hello world! 
' repeated 1000 times input_ids = torch.tensor( [[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device @@ -341,7 +353,7 @@ def test_inference_no_head(self): self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4)) @slow - def test_inference_masked_lm(self): + def test_inference_masked_lm_long(self): model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096") model.to(torch_device) From 18185a5eb1ae4e3be8f9a45b79212ded1a5f6a02 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 25 Jun 2020 15:25:19 +0200 Subject: [PATCH 06/19] refactor --- src/transformers/modeling_longformer.py | 62 +++++++++++++------------ 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index bfa3800cbf82..1e4a259aea49 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -652,6 +652,7 @@ class LongformerModel(LongformerPreTrainedModel): def __init__(self, config): super().__init__(config) + self.config = config if isinstance(config.attention_window, int): assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" @@ -690,10 +691,15 @@ def _pad_to_window_size( token_type_ids: torch.Tensor, position_ids: torch.Tensor, inputs_embeds: torch.Tensor, - attention_window: int, pad_token_id: int, ): """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" + # padding + attention_window = ( + self.config.attention_window + if isinstance(self.config.attention_window, int) + else max(self.config.attention_window) + ) assert attention_window % 2 == 0, f"`attention_window` should be an even value. 
Given {attention_window}" input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape @@ -726,6 +732,18 @@ def _pad_to_window_size( return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds + def _merge_to_attention_mask(attention_mask: torch.Tensor, global_attention_mask: torch.Tensor): + # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) + # (global_attention_mask + 1) => 1 for local attention, 2 for global attention + # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention + if attention_mask is not None: + attention_mask = attention_mask * (global_attention_mask + 1) + else: + # simply use `global_attention_mask` as `attention_mask` + # if no `attention_mask` is given + attention_mask = global_attention_mask + 1 + return attention_mask + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, @@ -781,6 +799,19 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + # merge `global_attention_mask` and `attention_mask` + if global_attention_mask is not None: + attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) + + padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pad_token_id=self.config.pad_token_id, + ) + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -797,35 +828,6 @@ def forward( if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # padding - attention_window = ( - 
self.config.attention_window - if isinstance(self.config.attention_window, int) - else max(self.config.attention_window) - ) - - # merge `global_attention_mask` and `attention_mask` - if global_attention_mask is not None: - # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) - # (global_attention_mask + 1) => 1 for local attention, 2 for global attention - # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention - if attention_mask is not None: - attention_mask = attention_mask * (global_attention_mask + 1) - else: - # simply use `global_attention_mask` as `attention_mask` - # if no `attention_mask` is given - attention_mask = global_attention_mask + 1 - - padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - attention_window=attention_window, - pad_token_id=self.config.pad_token_id, - ) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) From b7c2cd0b14fdb103bee13d41779e98f9512d7fe3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 25 Jun 2020 16:31:17 +0200 Subject: [PATCH 07/19] refactor naming --- src/transformers/modeling_longformer.py | 213 ++++++++++++------------ 1 file changed, 108 insertions(+), 105 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 1e4a259aea49..ecef3e155ccc 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -118,7 +118,7 @@ def __init__(self, config, layer_id): @staticmethod def pad_and_transpose_last_two_dims(hidden_states_padded, padding): - """Convert diagonals into columns (or columns into diagonals depending on `padding`""" + """Convert diagonals into columns or columns into diagonals depending on `padding`""" hidden_states_padded = F.pad( hidden_states_padded, padding ) # padding value is not important because it will be overwritten @@ -128,7 +128,7 @@ def pad_and_transpose_last_two_dims(hidden_states_padded, padding): return hidden_states_padded @staticmethod - def _value_attention_probs_skew(chunked_hidden_states): + def _pad_by_window_overlap_except_last_row(chunked_hidden_states): """shift every row 1 step to right converting columns into diagonals""" total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size() chunked_hidden_states = F.pad(chunked_hidden_states, (0, window_overlap + 1)) # B x C x M x (L+M+1). 
Padding value is not important because it'll be overwritten @@ -253,9 +253,9 @@ def _sliding_chunks_matmul_attention_probs_value(self, attention_probs: torch.Te chunked_value_stride = chunked_value_stride[0], window_overlap * chunked_value_stride[1], chunked_value_stride[1], chunked_value_stride[2] chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride) - skewed_prob = self._value_attention_probs_skew(chunked_attention_probs) + chunked_attention_probs = self._pad_by_window_overlap_except_last_row(chunked_attention_probs) - context = torch.einsum("bcwd,bcdh->bcwh", (skewed_prob, chunked_value)) + context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attention_probs, chunked_value)) return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2) def forward( @@ -272,63 +272,60 @@ def forward( """ - if attention_mask is not None: - attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) - key_padding_mask = attention_mask < 0 - extra_attention_mask = attention_mask > 0 - remove_from_windowed_attention_mask = attention_mask != 0 - - num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1) - max_num_extra_indices_per_batch = num_extra_indices_per_batch.max() - if max_num_extra_indices_per_batch <= 0: - extra_attention_mask = None - else: - # To support the case of variable number of global attention in the rows of a batch, - # we use the following three selection masks to select global attention embeddings - # in a 3d tensor and pad it to `max_num_extra_indices_per_batch` - # 1) selecting embeddings that correspond to global attention - extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True) - zero_to_max_range = torch.arange( - 0, max_num_extra_indices_per_batch, device=num_extra_indices_per_batch.device - ) - # mask indicating which values are actually going to be padding - selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1) - # 2) location of 
the non-padding values in the selected global attention - selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True) - # 3) location of the padding values in the selected global attention - selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True) - else: - remove_from_windowed_attention_mask = None - extra_attention_mask = None - key_padding_mask = None + attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) + key_padding_mask = attention_mask < 0 + extra_attention_mask = attention_mask > 0 + + num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1) + max_num_extra_indices_per_batch = num_extra_indices_per_batch.max() + + if max_num_extra_indices_per_batch > 0: + # To support the case of variable number of global attention in the rows of a batch, + # we use the following three selection masks to select global attention embeddings + # in a 3d tensor and pad it to `max_num_extra_indices_per_batch` + # 1) selecting embeddings that correspond to global attention + extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True) + zero_to_max_range = torch.arange( + 0, max_num_extra_indices_per_batch, device=num_extra_indices_per_batch.device + ) + # mask indicating which values are actually going to be padding + selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1) + # 2) location of the non-padding values in the selected global attention + selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True) + # 3) location of the padding values in the selected global attention + selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True) hidden_states = hidden_states.transpose(0, 1) seq_len, batch_size, embed_dim = hidden_states.size() - assert embed_dim == self.embed_dim - q = self.query(hidden_states) - k = self.key(hidden_states) - v = self.value(hidden_states) - q /= math.sqrt(self.head_dim) - - q = q.view(seq_len, 
batch_size, self.num_heads, self.head_dim).transpose(0, 1) - k = k.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - # attn_weights = (batch_size, seq_len, num_heads, window*2+1) - attn_weights = self._sliding_chunks_query_key_matmul(q, k, self.one_sided_attention_window_size) - if remove_from_windowed_attention_mask is not None: - # This implementation is fast and takes very little memory because num_heads x hidden_size = 1 - # from (batch_size x seq_len) to (batch_size x seq_len x num_heads x hidden_size) - remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze( - dim=-1 - ) - # cast to fp32/fp16 then replace 1's with -inf - float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill( - remove_from_windowed_attention_mask, -10000.0 - ) - ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones - # diagonal mask with zeros everywhere and -inf inplace of padding - d_mask = self._sliding_chunks_query_key_matmul(ones, float_mask, self.one_sided_attention_window_size) - attn_weights += d_mask - assert list(attn_weights.size()) == [ + assert embed_dim == self.embed_dim, f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}" + + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = self.value(hidden_states) + query_vectors /= math.sqrt(self.head_dim) + + query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + + # attention_probs = (batch_size, seq_len, num_heads, window*2+1) + attention_probs = self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, self.one_sided_attention_window_size) + + remove_from_windowed_attention_mask = attention_mask != 0 + # This implementation is fast and takes very little memory because num_heads x hidden_size = 1 + # from 
(batch_size x seq_len) to (batch_size x seq_len x num_heads x hidden_size) + remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze( + dim=-1 + ) + # cast to fp32/fp16 then replace 1's with -inf + float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill( + remove_from_windowed_attention_mask, -10000.0 + ) + ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones + # diagonal mask with zeros everywhere and -inf inplace of padding + d_mask = self._sliding_chunks_query_key_matmul(ones, float_mask, self.one_sided_attention_window_size) + attention_probs += d_mask + + assert list(attention_probs.size()) == [ batch_size, seq_len, self.num_heads, @@ -336,30 +333,30 @@ def forward( ] # the extra attention - if extra_attention_mask is not None: - selected_k = k.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) - selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros] + if max_num_extra_indices_per_batch > 0: + selected_k = key_vectors.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) + selected_k[selection_padding_mask_nonzeros] = key_vectors[extra_attention_mask_nonzeros] # (batch_size, seq_len, num_heads, max_num_extra_indices_per_batch) - selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k)) - selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0 - # concat to attn_weights + selected_attention_probs = torch.einsum("blhd,bshd->blhs", (query_vectors, selected_k)) + selected_attention_probs[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0 + # concat to attention_probs # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) - attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) + attention_probs = torch.cat((selected_attention_probs, attention_probs), dim=-1) - 
attn_weights_fp32 = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability - attn_weights = attn_weights_fp32.type_as(attn_weights) + attention_probs_fp32 = F.softmax(attention_probs, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + attention_probs = attention_probs_fp32.type_as(attention_probs) - if key_padding_mask is not None: - # softmax sometimes inserts NaN if all positions are masked, replace them with 0 - attn_weights = torch.masked_fill(attn_weights, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0) + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + attention_probs = torch.masked_fill(attention_probs, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0) - attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) - v = v.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + attn_probs = F.dropout(attention_probs, p=self.dropout, training=self.training) + value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) attn = None - if extra_attention_mask is not None: + + if max_num_extra_indices_per_batch > 0: selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch) - selected_v = v.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) - selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros] + selected_v = value_vectors.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) + selected_v[selection_padding_mask_nonzeros] = value_vectors[extra_attention_mask_nonzeros] # use `matmul` because `einsum` crashes sometimes with fp16 # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2)).transpose(1, 2) @@ -367,9 +364,9 @@ def forward( -1, max_num_extra_indices_per_batch, attn_probs.size(-1) - 
max_num_extra_indices_per_batch ).contiguous() if attn is None: - attn = self._sliding_chunks_matmul_attention_probs_value(attn_probs, v, self.one_sided_attention_window_size) + attn = self._sliding_chunks_matmul_attention_probs_value(attn_probs, value_vectors, self.one_sided_attention_window_size) else: - attn += self._sliding_chunks_matmul_attention_probs_value(attn_probs, v, self.one_sided_attention_window_size) + attn += self._sliding_chunks_matmul_attention_probs_value(attn_probs, value_vectors, self.one_sided_attention_window_size) assert attn.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" attn = attn.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() @@ -377,41 +374,45 @@ def forward( # For this case, we'll just recompute the attention for these indices # and overwrite the attn tensor. # TODO: remove the redundant computation - if extra_attention_mask is not None: + if max_num_extra_indices_per_batch > 0: selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, batch_size, embed_dim) selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[ extra_attention_mask_nonzeros[::-1] ] - q = self.query_global(selected_hidden_states) - k = self.key_global(hidden_states) - v = self.value_global(hidden_states) - q /= math.sqrt(self.head_dim) + global_query_vectors = self.query_global(selected_hidden_states) + global_key_vectors = self.key_global(hidden_states) + global_value_vectors = self.value_global(hidden_states) - q = ( - q.contiguous() + global_query_vectors /= math.sqrt(self.head_dim) + + global_query_vectors = ( + global_query_vectors.contiguous() .view(max_num_extra_indices_per_batch, batch_size * self.num_heads, self.head_dim) .transpose(0, 1) ) # (batch_size * self.num_heads, max_num_extra_indices_per_batch, head_dim) - k = ( - k.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + global_key_vectors = ( + 
global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) ) # batch_size * self.num_heads, seq_len, head_dim) - v = ( - v.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + global_value_vectors = ( + global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) ) # batch_size * self.num_heads, seq_len, head_dim) - attn_weights = torch.bmm(q, k.transpose(1, 2)) - assert list(attn_weights.size()) == [batch_size * self.num_heads, max_num_extra_indices_per_batch, seq_len] - - attn_weights = attn_weights.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seq_len) - attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0 - if key_padding_mask is not None: - attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -10000.0,) - attn_weights = attn_weights.view(batch_size * self.num_heads, max_num_extra_indices_per_batch, seq_len) - attn_weights_float = F.softmax( - attn_weights, dim=-1, dtype=torch.float32 + attention_probs = torch.bmm(global_query_vectors, global_key_vectors.transpose(1, 2)) + assert list(attention_probs.size()) == [batch_size * self.num_heads, max_num_extra_indices_per_batch, seq_len] + + attention_probs = attention_probs.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seq_len) + attention_probs[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0 + + attention_probs = attention_probs.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -10000.0,) + + attention_probs = attention_probs.view(batch_size * self.num_heads, max_num_extra_indices_per_batch, seq_len) + attention_probs_float = F.softmax( + attention_probs, dim=-1, dtype=torch.float32 ) # use fp32 for numerical stability - attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) - selected_attn = 
torch.bmm(attn_probs, v) + + attention_probs = F.dropout(attention_probs_float.type_as(attention_probs), p=self.dropout, training=self.training) + selected_attn = torch.bmm(attention_probs, global_key_vectors) + assert list(selected_attn.size()) == [ batch_size * self.num_heads, max_num_extra_indices_per_batch, @@ -429,6 +430,7 @@ def forward( ) context_layer = attn.transpose(0, 1) + if output_attentions: if extra_attention_mask is not None: # With global attention, return global attention probabilities only @@ -436,14 +438,15 @@ def forward( # which is the attention weights from tokens with global attention to all tokens # It doesn't not return local attention # In case of variable number of global attantion in the rows of a batch, - # attn_weights are padded with -10000.0 attention scores - attn_weights = attn_weights.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seq_len) + # attention_probs are padded with -10000.0 attention scores + attention_probs = attention_probs.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seq_len) else: # without global attention, return local attention probabilities # batch_size x num_heads x sequence_length x window_size # which is the attention weights of every token attending to its neighbours - attn_weights = attn_weights.permute(0, 2, 1, 3) - outputs = (context_layer, attn_weights) if output_attentions else (context_layer,) + attention_probs = attention_probs.permute(0, 2, 1, 3) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) return outputs @@ -732,7 +735,7 @@ def _pad_to_window_size( return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds - def _merge_to_attention_mask(attention_mask: torch.Tensor, global_attention_mask: torch.Tensor): + def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor): # longformer self attention expects attention mask to have 0 (no attn), 1 
(local attn), 2 (global attn) # (global_attention_mask + 1) => 1 for local attention, 2 for global attention # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention From 6bb7379e9b9c90b81a208c846d05b9cdea952035 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 25 Jun 2020 16:39:48 +0200 Subject: [PATCH 08/19] rename selected to extra --- src/transformers/modeling_longformer.py | 29 +++++++++++++++---------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index ecef3e155ccc..c36c399ba64a 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -310,42 +310,49 @@ def forward( # attention_probs = (batch_size, seq_len, num_heads, window*2+1) attention_probs = self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, self.one_sided_attention_window_size) - remove_from_windowed_attention_mask = attention_mask != 0 # This implementation is fast and takes very little memory because num_heads x hidden_size = 1 # from (batch_size x seq_len) to (batch_size x seq_len x num_heads x hidden_size) - remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze( + remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze( dim=-1 ) + # cast to fp32/fp16 then replace 1's with -inf float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill( remove_from_windowed_attention_mask, -10000.0 ) - ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones # diagonal mask with zeros everywhere and -inf inplace of padding - d_mask = self._sliding_chunks_query_key_matmul(ones, float_mask, self.one_sided_attention_window_size) - attention_probs += d_mask + diagonal_mask = self._sliding_chunks_query_key_matmul(float_mask.new_ones(size=float_mask.size()), float_mask, 
self.one_sided_attention_window_size) + attention_probs += diagonal_mask assert list(attention_probs.size()) == [ batch_size, seq_len, self.num_heads, self.one_sided_attention_window_size * 2 + 1, - ] + ], f"attention_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attention_window_size * 2 + 1}), but is of size {attention_probs.size()}" # the extra attention if max_num_extra_indices_per_batch > 0: - selected_k = key_vectors.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) - selected_k[selection_padding_mask_nonzeros] = key_vectors[extra_attention_mask_nonzeros] + extra_key_vectors = key_vectors.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) + extra_key_vectors[selection_padding_mask_nonzeros] = key_vectors[extra_attention_mask_nonzeros] # (batch_size, seq_len, num_heads, max_num_extra_indices_per_batch) - selected_attention_probs = torch.einsum("blhd,bshd->blhs", (query_vectors, selected_k)) - selected_attention_probs[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0 + + extra_attention_probs = torch.einsum("blhd,bshd->blhs", (query_vectors, extra_key_vectors)) + extra_attention_probs[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0 + # concat to attention_probs # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) - attention_probs = torch.cat((selected_attention_probs, attention_probs), dim=-1) + attention_probs = torch.cat((extra_attention_probs, attention_probs), dim=-1) + + # free memory + del key_vectors, query_vectors, extra_key_vectors, extra_attention_probs attention_probs_fp32 = F.softmax(attention_probs, dim=-1, dtype=torch.float32) # use fp32 for numerical stability attention_probs = attention_probs_fp32.type_as(attention_probs) + # free memory + del attention_probs_fp32 + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 
attention_probs = torch.masked_fill(attention_probs, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0) From 0c29a39add5fc697aa72d75c80da8c6421590429 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 25 Jun 2020 18:37:58 +0200 Subject: [PATCH 09/19] big global attention refactor --- src/transformers/modeling_longformer.py | 197 +++++++++++++----------- tests/test_modeling_longformer.py | 27 ++++ 2 files changed, 138 insertions(+), 86 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index c36c399ba64a..a6608b75caf8 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -273,44 +273,53 @@ def forward( """ attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) - key_padding_mask = attention_mask < 0 - extra_attention_mask = attention_mask > 0 - num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1) - max_num_extra_indices_per_batch = num_extra_indices_per_batch.max() + # key tokens to be padded + is_index_masked = attention_mask < 0 - if max_num_extra_indices_per_batch > 0: + # all global attention tokens + is_index_global_attention = attention_mask > 0 + + # how many global attention tokens + num_global_attn_indices_per_batch = is_index_global_attention.long().sum(dim=1) + + # max global attention tokens of all batches + max_num_global_attn_indices_of_batches = num_global_attn_indices_per_batch.max() + + if max_num_global_attn_indices_of_batches > 0: # To support the case of variable number of global attention in the rows of a batch, # we use the following three selection masks to select global attention embeddings - # in a 3d tensor and pad it to `max_num_extra_indices_per_batch` + # in a 3d tensor and pad it to `max_num_global_attn_indices_of_batches` # 1) selecting embeddings that correspond to global attention - extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True) - zero_to_max_range = torch.arange( - 0, 
max_num_extra_indices_per_batch, device=num_extra_indices_per_batch.device - ) - # mask indicating which values are actually going to be padding - selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1) + is_index_global_attention_nonzeros = is_index_global_attention.nonzero(as_tuple=True) + + # mask indicating which values are actually going to be padded for global attention computation + is_local_index_global_attention = torch.arange(max_num_global_attn_indices_of_batches, device=attention_mask.device) < num_global_attn_indices_per_batch.unsqueeze(dim=-1) + # 2) location of the non-padding values in the selected global attention - selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True) + is_local_index_global_attention_indices = is_local_index_global_attention.nonzero(as_tuple=True) # 3) location of the padding values in the selected global attention - selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True) + local_index_no_global_attention_indices = (is_local_index_global_attention == 0).nonzero(as_tuple=True) hidden_states = hidden_states.transpose(0, 1) - seq_len, batch_size, embed_dim = hidden_states.size() - assert embed_dim == self.embed_dim, f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}" + # project hidden states query_vectors = self.query(hidden_states) key_vectors = self.key(hidden_states) value_vectors = self.value(hidden_states) + + seq_len, batch_size, embed_dim = hidden_states.size() + assert embed_dim == self.embed_dim, f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}" + + # normalize query query_vectors /= math.sqrt(self.head_dim) query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - # attention_probs = (batch_size, seq_len, num_heads, window*2+1) - 
attention_probs = self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, self.one_sided_attention_window_size) + # local_attention_probs = (batch_size, seq_len, num_heads, window*2+1) + local_attention_probs = self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, self.one_sided_attention_window_size) - # This implementation is fast and takes very little memory because num_heads x hidden_size = 1 # from (batch_size x seq_len) to (batch_size x seq_len x num_heads x hidden_size) remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze( dim=-1 @@ -322,138 +331,154 @@ def forward( ) # diagonal mask with zeros everywhere and -inf inplace of padding diagonal_mask = self._sliding_chunks_query_key_matmul(float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attention_window_size) - attention_probs += diagonal_mask - assert list(attention_probs.size()) == [ + # pad local attention probs + local_attention_probs += diagonal_mask + + assert list(local_attention_probs.size()) == [ batch_size, seq_len, self.num_heads, self.one_sided_attention_window_size * 2 + 1, - ], f"attention_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attention_window_size * 2 + 1}), but is of size {attention_probs.size()}" + ], f"local_attention_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attention_window_size * 2 + 1}), but is of size {local_attention_probs.size()}" - # the extra attention - if max_num_extra_indices_per_batch > 0: - extra_key_vectors = key_vectors.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) - extra_key_vectors[selection_padding_mask_nonzeros] = key_vectors[extra_attention_mask_nonzeros] - # (batch_size, seq_len, num_heads, max_num_extra_indices_per_batch) + # compute local attention probs from global attention keys and contact over window dim + if max_num_global_attn_indices_of_batches > 0: + # 
create only global key vectors + key_vectors_only_global = key_vectors.new_zeros(batch_size, max_num_global_attn_indices_of_batches, self.num_heads, self.head_dim) + key_vectors_only_global[is_local_index_global_attention_indices] = key_vectors[is_index_global_attention_nonzeros] - extra_attention_probs = torch.einsum("blhd,bshd->blhs", (query_vectors, extra_key_vectors)) - extra_attention_probs[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0 + # (batch_size, seq_len, num_heads, max_num_global_attn_indices_of_batches) + attention_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) + attention_probs_from_global_key[local_index_no_global_attention_indices[0], :, :, local_index_no_global_attention_indices[1]] = -10000.0 - # concat to attention_probs + # concat to local_attention_probs # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) - attention_probs = torch.cat((extra_attention_probs, attention_probs), dim=-1) + local_attention_probs = torch.cat((attention_probs_from_global_key, local_attention_probs), dim=-1) # free memory - del key_vectors, query_vectors, extra_key_vectors, extra_attention_probs + del key_vectors, query_vectors, key_vectors_only_global, attention_probs_from_global_key - attention_probs_fp32 = F.softmax(attention_probs, dim=-1, dtype=torch.float32) # use fp32 for numerical stability - attention_probs = attention_probs_fp32.type_as(attention_probs) + local_attention_probs_fp32 = F.softmax(local_attention_probs, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + local_attention_probs = local_attention_probs_fp32.type_as(local_attention_probs) # free memory - del attention_probs_fp32 + del local_attention_probs_fp32 # softmax sometimes inserts NaN if all positions are masked, replace them with 0 - attention_probs = torch.masked_fill(attention_probs, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0) + local_attention_probs = 
torch.masked_fill(local_attention_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0) - attn_probs = F.dropout(attention_probs, p=self.dropout, training=self.training) + local_attention_probs = F.dropout(local_attention_probs, p=self.dropout, training=self.training) value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - attn = None - if max_num_extra_indices_per_batch > 0: - selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch) - selected_v = value_vectors.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) - selected_v[selection_padding_mask_nonzeros] = value_vectors[extra_attention_mask_nonzeros] + # compute local attention output with global attention value and add + if max_num_global_attn_indices_of_batches > 0: + local_attn_probs_only_global = local_attention_probs.narrow(-1, 0, max_num_global_attn_indices_of_batches) + +# only_global_value_vectors = value_vectors.new_zeros(batch_size, max_num_global_attn_indices_of_batches, self.num_heads, self.head_dim) + + value_vectors_only_global = value_vectors.new_zeros(batch_size, max_num_global_attn_indices_of_batches, self.num_heads, self.head_dim) + value_vectors_only_global[is_local_index_global_attention_indices] = value_vectors[is_index_global_attention_nonzeros] + # use `matmul` because `einsum` crashes sometimes with fp16 # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) - attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2)).transpose(1, 2) - attn_probs = attn_probs.narrow( - -1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch + local_attention_output_only_global = torch.matmul(local_attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2)).transpose(1, 2) + + local_attn_probs_without_global = local_attention_probs.narrow( + -1, max_num_global_attn_indices_of_batches, 
local_attention_probs.size(-1) - max_num_global_attn_indices_of_batches ).contiguous() - if attn is None: - attn = self._sliding_chunks_matmul_attention_probs_value(attn_probs, value_vectors, self.one_sided_attention_window_size) + # add computed attention output + attention_output = local_attention_output_only_global + self._sliding_chunks_matmul_attention_probs_value(local_attn_probs_without_global, value_vectors, self.one_sided_attention_window_size) + + # free memory + del local_attn_probs_only_global, value_vectors_only_global, local_attention_output_only_global, local_attn_probs_without_global else: - attn += self._sliding_chunks_matmul_attention_probs_value(attn_probs, value_vectors, self.one_sided_attention_window_size) + # compute local attention + attention_output = self._sliding_chunks_matmul_attention_probs_value(local_attention_probs, value_vectors, self.one_sided_attention_window_size) - assert attn.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" - attn = attn.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() + assert attention_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" + attention_output = attention_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() - # For this case, we'll just recompute the attention for these indices - # and overwrite the attn tensor. 
+ # compute value for global attention and overwrite to attention output # TODO: remove the redundant computation - if max_num_extra_indices_per_batch > 0: - selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, batch_size, embed_dim) - selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[ - extra_attention_mask_nonzeros[::-1] + if max_num_global_attn_indices_of_batches > 0: + only_global_attention_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices_of_batches, batch_size, embed_dim) + only_global_attention_hidden_states[is_local_index_global_attention_indices[::-1]] = hidden_states[ + is_index_global_attention_nonzeros[::-1] ] - global_query_vectors = self.query_global(selected_hidden_states) + only_global_query_vectors = self.query_global(only_global_attention_hidden_states) global_key_vectors = self.key_global(hidden_states) global_value_vectors = self.value_global(hidden_states) - global_query_vectors /= math.sqrt(self.head_dim) + # normalize + only_global_query_vectors /= math.sqrt(self.head_dim) - global_query_vectors = ( - global_query_vectors.contiguous() - .view(max_num_extra_indices_per_batch, batch_size * self.num_heads, self.head_dim) + only_global_query_vectors = ( + only_global_query_vectors.contiguous() + .view(max_num_global_attn_indices_of_batches, batch_size * self.num_heads, self.head_dim) .transpose(0, 1) - ) # (batch_size * self.num_heads, max_num_extra_indices_per_batch, head_dim) + ) # (batch_size * self.num_heads, max_num_global_attn_indices_of_batches, head_dim) global_key_vectors = ( global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) ) # batch_size * self.num_heads, seq_len, head_dim) global_value_vectors = ( global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) ) # batch_size * self.num_heads, seq_len, head_dim) - attention_probs = torch.bmm(global_query_vectors, 
global_key_vectors.transpose(1, 2)) - assert list(attention_probs.size()) == [batch_size * self.num_heads, max_num_extra_indices_per_batch, seq_len] + global_attention_probs = torch.bmm(only_global_query_vectors, global_key_vectors.transpose(1, 2)) + assert list(global_attention_probs.size()) == [batch_size * self.num_heads, max_num_global_attn_indices_of_batches, seq_len], f"global_attention_probs have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices_of_batches, seq_len)}, but is {global_attention_probs.size()}." + + global_attention_probs = global_attention_probs.view(batch_size, self.num_heads, max_num_global_attn_indices_of_batches, seq_len) - attention_probs = attention_probs.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seq_len) - attention_probs[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0 + global_attention_probs[local_index_no_global_attention_indices[0], :, local_index_no_global_attention_indices[1], :] = -10000.0 - attention_probs = attention_probs.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -10000.0,) + global_attention_probs = global_attention_probs.masked_fill(is_index_masked.unsqueeze(1).unsqueeze(2), -10000.0,) - attention_probs = attention_probs.view(batch_size * self.num_heads, max_num_extra_indices_per_batch, seq_len) - attention_probs_float = F.softmax( - attention_probs, dim=-1, dtype=torch.float32 + global_attention_probs = global_attention_probs.view(batch_size * self.num_heads, max_num_global_attn_indices_of_batches, seq_len) + global_attention_probs_float = F.softmax( + global_attention_probs, dim=-1, dtype=torch.float32 ) # use fp32 for numerical stability - attention_probs = F.dropout(attention_probs_float.type_as(attention_probs), p=self.dropout, training=self.training) - selected_attn = torch.bmm(attention_probs, global_key_vectors) + global_attention_probs = 
F.dropout(global_attention_probs_float.type_as(global_attention_probs), p=self.dropout, training=self.training) + + global_attention_output = torch.bmm(global_attention_probs, global_key_vectors) - assert list(selected_attn.size()) == [ + assert list(global_attention_output.size()) == [ batch_size * self.num_heads, - max_num_extra_indices_per_batch, + max_num_global_attn_indices_of_batches, self.head_dim, - ] + ], f"global_attention_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices_of_batches, self.head_dim)}, but is {global_attention_output.size()}." - selected_attn_4d = selected_attn.view( - batch_size, self.num_heads, max_num_extra_indices_per_batch, self.head_dim + global_attention_output = global_attention_output.view( + batch_size, self.num_heads, max_num_global_attn_indices_of_batches, self.head_dim ) - nonzero_selected_attn = selected_attn_4d[ - selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1] + nonzero_global_attention_output = global_attention_output[ + is_local_index_global_attention_indices[0], :, is_local_index_global_attention_indices[1] ] - attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view( - len(selection_padding_mask_nonzeros[0]), -1 + + # overwrite values with global attention + attention_output[is_index_global_attention_nonzeros[::-1]] = nonzero_global_attention_output.view( + len(is_local_index_global_attention_indices[0]), -1 ) - context_layer = attn.transpose(0, 1) + attention_output = attention_output.transpose(0, 1) if output_attentions: - if extra_attention_mask is not None: + if max_num_global_attn_indices_of_batches > 0: # With global attention, return global attention probabilities only # batch_size x num_heads x max_num_global_attention_tokens x sequence_length # which is the attention weights from tokens with global attention to all tokens # It doesn't not return local attention # In case of variable number of global attantion in the 
rows of a batch, # attention_probs are padded with -10000.0 attention scores - attention_probs = attention_probs.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seq_len) + local_attention_probs = local_attention_probs.view(batch_size, self.num_heads, max_num_global_attn_indices_of_batches, seq_len) else: # without global attention, return local attention probabilities # batch_size x num_heads x sequence_length x window_size # which is the attention weights of every token attending to its neighbours - attention_probs = attention_probs.permute(0, 2, 1, 3) + local_attention_probs = local_attention_probs.permute(0, 2, 1, 3) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + outputs = (attention_output, local_attention_probs) if output_attentions else (attention_output,) return outputs diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index c72406bb762c..0ff511decbc6 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -134,6 +134,29 @@ def create_and_check_longformer_model( ) self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) + def create_and_check_longformer_model_with_global_attention_mask( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = LongformerModel(config=config) + model.to(torch_device) + model.eval() + global_attention_mask = input_mask.clone() + global_attention_mask[:, input_mask.shape[-1] // 2] = 0 + global_attention_mask = global_attention_mask.to(torch_device) + + sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, global_attention_mask=global_attention_mask, token_type_ids=token_type_ids) + sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask) + sequence_output, pooled_output = model(input_ids, 
global_attention_mask=global_attention_mask) + + result = { + "sequence_output": sequence_output, + "pooled_output": pooled_output, + } + self.parent.assertListEqual( + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] + ) + self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) + def create_and_check_longformer_for_masked_lm( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -298,6 +321,10 @@ def test_longformer_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_longformer_model(*config_and_inputs) + def test_longformer_model_global_attention_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs) + def test_longformer_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs) From 3012770e161a2ba339fea29a52381193bb8caba4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 25 Jun 2020 18:38:19 +0200 Subject: [PATCH 10/19] make style --- src/transformers/modeling_longformer.py | 134 ++++++++++++++++++------ tests/test_modeling_longformer.py | 11 +- 2 files changed, 109 insertions(+), 36 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index a6608b75caf8..aceafff5d07e 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -131,10 +131,14 @@ def pad_and_transpose_last_two_dims(hidden_states_padded, padding): def _pad_by_window_overlap_except_last_row(chunked_hidden_states): """shift every row 1 step to right converting columns into diagonals""" total_num_heads, num_chunks, window_overlap, hidden_dim = 
chunked_hidden_states.size() - chunked_hidden_states = F.pad(chunked_hidden_states, (0, window_overlap + 1)) # B x C x M x (L+M+1). Padding value is not important because it'll be overwritten + chunked_hidden_states = F.pad( + chunked_hidden_states, (0, window_overlap + 1) + ) # B x C x M x (L+M+1). Padding value is not important because it'll be overwritten chunked_hidden_states = chunked_hidden_states.view(total_num_heads, num_chunks, -1) # B x C x ML+MM+M chunked_hidden_states = chunked_hidden_states[:, :, :-window_overlap] # B x C x ML+MM - chunked_hidden_states = chunked_hidden_states.view(total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim) # B x C, M x L+M + chunked_hidden_states = chunked_hidden_states.view( + total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim + ) # B x C, M x L+M chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] return chunked_hidden_states @@ -195,7 +199,9 @@ def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tenso chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunked_query, chunked_key)) # multiply # convert diagonals into columns - diagonal_chunked_attention_scores = self.pad_and_transpose_last_two_dims(chunked_attention_scores, padding=(0, 0, 0, 1)) + diagonal_chunked_attention_scores = self.pad_and_transpose_last_two_dims( + chunked_attention_scores, padding=(0, 0, 0, 1) + ) # allocate space for the overall attention matrix where the chunks are compined. The last dimension # has (w * 2 + 1) columns. 
The first (w) columns are the w lower triangles (attention from a word to @@ -230,7 +236,9 @@ def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tenso self._mask_invalid_locations(diagonal_attention_scores, window_overlap) return diagonal_attention_scores - def _sliding_chunks_matmul_attention_probs_value(self, attention_probs: torch.Tensor, value: torch.Tensor, window_overlap: int): + def _sliding_chunks_matmul_attention_probs_value( + self, attention_probs: torch.Tensor, value: torch.Tensor, window_overlap: int + ): """Same as _sliding_chunks_query_key_matmul but for attention_probs and value tensors. It is expecting the same output format from _sliding_chunks_query_key_matmul""" batch_size, seq_len, num_heads, head_dim = value.size() @@ -239,7 +247,9 @@ def _sliding_chunks_matmul_attention_probs_value(self, attention_probs: torch.Te assert attention_probs.size(3) == 2 * window_overlap + 1 chunks_count = seq_len // window_overlap - 1 # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap - chunked_attention_probs = attention_probs.transpose(1, 2).reshape(batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1) + chunked_attention_probs = attention_probs.transpose(1, 2).reshape( + batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1 + ) # group batch_size and num_heads dimensions into one value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) @@ -250,7 +260,12 @@ def _sliding_chunks_matmul_attention_probs_value(self, attention_probs: torch.Te # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim) chunked_value_stride = padded_value.stride() - chunked_value_stride = chunked_value_stride[0], window_overlap * chunked_value_stride[1], chunked_value_stride[1], 
chunked_value_stride[2] + chunked_value_stride = ( + chunked_value_stride[0], + window_overlap * chunked_value_stride[1], + chunked_value_stride[1], + chunked_value_stride[2], + ) chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride) chunked_attention_probs = self._pad_by_window_overlap_except_last_row(chunked_attention_probs) @@ -294,7 +309,9 @@ def forward( is_index_global_attention_nonzeros = is_index_global_attention.nonzero(as_tuple=True) # mask indicating which values are actually going to be padded for global attention computation - is_local_index_global_attention = torch.arange(max_num_global_attn_indices_of_batches, device=attention_mask.device) < num_global_attn_indices_per_batch.unsqueeze(dim=-1) + is_local_index_global_attention = torch.arange( + max_num_global_attn_indices_of_batches, device=attention_mask.device + ) < num_global_attn_indices_per_batch.unsqueeze(dim=-1) # 2) location of the non-padding values in the selected global attention is_local_index_global_attention_indices = is_local_index_global_attention.nonzero(as_tuple=True) @@ -309,7 +326,9 @@ def forward( value_vectors = self.value(hidden_states) seq_len, batch_size, embed_dim = hidden_states.size() - assert embed_dim == self.embed_dim, f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}" + assert ( + embed_dim == self.embed_dim + ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}" # normalize query query_vectors /= math.sqrt(self.head_dim) @@ -318,19 +337,21 @@ def forward( key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) # local_attention_probs = (batch_size, seq_len, num_heads, window*2+1) - local_attention_probs = self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, self.one_sided_attention_window_size) + local_attention_probs = self._sliding_chunks_query_key_matmul( + query_vectors, key_vectors, 
self.one_sided_attention_window_size + ) # from (batch_size x seq_len) to (batch_size x seq_len x num_heads x hidden_size) - remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze( - dim=-1 - ) + remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1) # cast to fp32/fp16 then replace 1's with -inf float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill( remove_from_windowed_attention_mask, -10000.0 ) # diagonal mask with zeros everywhere and -inf inplace of padding - diagonal_mask = self._sliding_chunks_query_key_matmul(float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attention_window_size) + diagonal_mask = self._sliding_chunks_query_key_matmul( + float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attention_window_size + ) # pad local attention probs local_attention_probs += diagonal_mask @@ -345,12 +366,18 @@ def forward( # compute local attention probs from global attention keys and contact over window dim if max_num_global_attn_indices_of_batches > 0: # create only global key vectors - key_vectors_only_global = key_vectors.new_zeros(batch_size, max_num_global_attn_indices_of_batches, self.num_heads, self.head_dim) - key_vectors_only_global[is_local_index_global_attention_indices] = key_vectors[is_index_global_attention_nonzeros] + key_vectors_only_global = key_vectors.new_zeros( + batch_size, max_num_global_attn_indices_of_batches, self.num_heads, self.head_dim + ) + key_vectors_only_global[is_local_index_global_attention_indices] = key_vectors[ + is_index_global_attention_nonzeros + ] # (batch_size, seq_len, num_heads, max_num_global_attn_indices_of_batches) attention_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) - attention_probs_from_global_key[local_index_no_global_attention_indices[0], :, :, local_index_no_global_attention_indices[1]] = -10000.0 + 
attention_probs_from_global_key[ + local_index_no_global_attention_indices[0], :, :, local_index_no_global_attention_indices[1] + ] = -10000.0 # concat to local_attention_probs # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) @@ -359,14 +386,18 @@ def forward( # free memory del key_vectors, query_vectors, key_vectors_only_global, attention_probs_from_global_key - local_attention_probs_fp32 = F.softmax(local_attention_probs, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + local_attention_probs_fp32 = F.softmax( + local_attention_probs, dim=-1, dtype=torch.float32 + ) # use fp32 for numerical stability local_attention_probs = local_attention_probs_fp32.type_as(local_attention_probs) # free memory del local_attention_probs_fp32 # softmax sometimes inserts NaN if all positions are masked, replace them with 0 - local_attention_probs = torch.masked_fill(local_attention_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0) + local_attention_probs = torch.masked_fill( + local_attention_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0 + ) local_attention_probs = F.dropout(local_attention_probs, p=self.dropout, training=self.training) value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) @@ -375,26 +406,43 @@ def forward( if max_num_global_attn_indices_of_batches > 0: local_attn_probs_only_global = local_attention_probs.narrow(-1, 0, max_num_global_attn_indices_of_batches) -# only_global_value_vectors = value_vectors.new_zeros(batch_size, max_num_global_attn_indices_of_batches, self.num_heads, self.head_dim) + # only_global_value_vectors = value_vectors.new_zeros(batch_size, max_num_global_attn_indices_of_batches, self.num_heads, self.head_dim) - value_vectors_only_global = value_vectors.new_zeros(batch_size, max_num_global_attn_indices_of_batches, self.num_heads, self.head_dim) - value_vectors_only_global[is_local_index_global_attention_indices] = 
value_vectors[is_index_global_attention_nonzeros] + value_vectors_only_global = value_vectors.new_zeros( + batch_size, max_num_global_attn_indices_of_batches, self.num_heads, self.head_dim + ) + value_vectors_only_global[is_local_index_global_attention_indices] = value_vectors[ + is_index_global_attention_nonzeros + ] # use `matmul` because `einsum` crashes sometimes with fp16 # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) - local_attention_output_only_global = torch.matmul(local_attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2)).transpose(1, 2) + local_attention_output_only_global = torch.matmul( + local_attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2) + ).transpose(1, 2) local_attn_probs_without_global = local_attention_probs.narrow( - -1, max_num_global_attn_indices_of_batches, local_attention_probs.size(-1) - max_num_global_attn_indices_of_batches + -1, + max_num_global_attn_indices_of_batches, + local_attention_probs.size(-1) - max_num_global_attn_indices_of_batches, ).contiguous() # add computed attention output - attention_output = local_attention_output_only_global + self._sliding_chunks_matmul_attention_probs_value(local_attn_probs_without_global, value_vectors, self.one_sided_attention_window_size) + attention_output = local_attention_output_only_global + self._sliding_chunks_matmul_attention_probs_value( + local_attn_probs_without_global, value_vectors, self.one_sided_attention_window_size + ) # free memory - del local_attn_probs_only_global, value_vectors_only_global, local_attention_output_only_global, local_attn_probs_without_global + del ( + local_attn_probs_only_global, + value_vectors_only_global, + local_attention_output_only_global, + local_attn_probs_without_global, + ) else: # compute local attention - attention_output = self._sliding_chunks_matmul_attention_probs_value(local_attention_probs, value_vectors, self.one_sided_attention_window_size) + 
attention_output = self._sliding_chunks_matmul_attention_probs_value( + local_attention_probs, value_vectors, self.one_sided_attention_window_size + ) assert attention_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" attention_output = attention_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() @@ -402,7 +450,9 @@ def forward( # compute value for global attention and overwrite to attention output # TODO: remove the redundant computation if max_num_global_attn_indices_of_batches > 0: - only_global_attention_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices_of_batches, batch_size, embed_dim) + only_global_attention_hidden_states = hidden_states.new_zeros( + max_num_global_attn_indices_of_batches, batch_size, embed_dim + ) only_global_attention_hidden_states[is_local_index_global_attention_indices[::-1]] = hidden_states[ is_index_global_attention_nonzeros[::-1] ] @@ -426,20 +476,34 @@ def forward( global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) ) # batch_size * self.num_heads, seq_len, head_dim) global_attention_probs = torch.bmm(only_global_query_vectors, global_key_vectors.transpose(1, 2)) - assert list(global_attention_probs.size()) == [batch_size * self.num_heads, max_num_global_attn_indices_of_batches, seq_len], f"global_attention_probs have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices_of_batches, seq_len)}, but is {global_attention_probs.size()}." + assert list(global_attention_probs.size()) == [ + batch_size * self.num_heads, + max_num_global_attn_indices_of_batches, + seq_len, + ], f"global_attention_probs have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices_of_batches, seq_len)}, but is {global_attention_probs.size()}." 
- global_attention_probs = global_attention_probs.view(batch_size, self.num_heads, max_num_global_attn_indices_of_batches, seq_len) + global_attention_probs = global_attention_probs.view( + batch_size, self.num_heads, max_num_global_attn_indices_of_batches, seq_len + ) - global_attention_probs[local_index_no_global_attention_indices[0], :, local_index_no_global_attention_indices[1], :] = -10000.0 + global_attention_probs[ + local_index_no_global_attention_indices[0], :, local_index_no_global_attention_indices[1], : + ] = -10000.0 - global_attention_probs = global_attention_probs.masked_fill(is_index_masked.unsqueeze(1).unsqueeze(2), -10000.0,) + global_attention_probs = global_attention_probs.masked_fill( + is_index_masked.unsqueeze(1).unsqueeze(2), -10000.0, + ) - global_attention_probs = global_attention_probs.view(batch_size * self.num_heads, max_num_global_attn_indices_of_batches, seq_len) + global_attention_probs = global_attention_probs.view( + batch_size * self.num_heads, max_num_global_attn_indices_of_batches, seq_len + ) global_attention_probs_float = F.softmax( global_attention_probs, dim=-1, dtype=torch.float32 ) # use fp32 for numerical stability - global_attention_probs = F.dropout(global_attention_probs_float.type_as(global_attention_probs), p=self.dropout, training=self.training) + global_attention_probs = F.dropout( + global_attention_probs_float.type_as(global_attention_probs), p=self.dropout, training=self.training + ) global_attention_output = torch.bmm(global_attention_probs, global_key_vectors) @@ -471,7 +535,9 @@ def forward( # It doesn't not return local attention # In case of variable number of global attantion in the rows of a batch, # attention_probs are padded with -10000.0 attention scores - local_attention_probs = local_attention_probs.view(batch_size, self.num_heads, max_num_global_attn_indices_of_batches, seq_len) + local_attention_probs = local_attention_probs.view( + batch_size, self.num_heads, 
max_num_global_attn_indices_of_batches, seq_len + ) else: # without global attention, return local attention probabilities # batch_size x num_heads x sequence_length x window_size diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index 0ff511decbc6..f49f3ca2dfe9 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -144,8 +144,15 @@ def create_and_check_longformer_model_with_global_attention_mask( global_attention_mask[:, input_mask.shape[-1] // 2] = 0 global_attention_mask = global_attention_mask.to(torch_device) - sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, global_attention_mask=global_attention_mask, token_type_ids=token_type_ids) - sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask) + sequence_output, pooled_output = model( + input_ids, + attention_mask=input_mask, + global_attention_mask=global_attention_mask, + token_type_ids=token_type_ids, + ) + sequence_output, pooled_output = model( + input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask + ) sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask) result = { From e66015bc61a7d5667ff97e26baad9639151ad874 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 25 Jun 2020 19:18:09 +0000 Subject: [PATCH 11/19] refactor naming --- src/transformers/modeling_longformer.py | 45 +++++++++++++------------ tests/test_modeling_longformer.py | 34 +++++++++++++++---- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index aceafff5d07e..7c07c4d6c54c 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -426,9 +426,13 @@ def forward( max_num_global_attn_indices_of_batches, local_attention_probs.size(-1) - 
max_num_global_attn_indices_of_batches, ).contiguous() + # add computed attention output - attention_output = local_attention_output_only_global + self._sliding_chunks_matmul_attention_probs_value( - local_attn_probs_without_global, value_vectors, self.one_sided_attention_window_size + attention_output = ( + self._sliding_chunks_matmul_attention_probs_value( + local_attn_probs_without_global, value_vectors, self.one_sided_attention_window_size + ) + + local_attention_output_only_global ) # free memory @@ -505,7 +509,7 @@ def forward( global_attention_probs_float.type_as(global_attention_probs), p=self.dropout, training=self.training ) - global_attention_output = torch.bmm(global_attention_probs, global_key_vectors) + global_attention_output = torch.bmm(global_attention_probs, global_value_vectors) assert list(global_attention_output.size()) == [ batch_size * self.num_heads, @@ -815,12 +819,6 @@ def _pad_to_window_size( ) if input_ids is not None: input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id) - if attention_mask is not None: - attention_mask = F.pad( - attention_mask, (0, padding_len), value=False - ) # no attention on the padding tokens - if token_type_ids is not None: - token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0 if position_ids is not None: # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id) @@ -831,6 +829,9 @@ def _pad_to_window_size( inputs_embeds_padding = self.embeddings(input_ids_padding) inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2) + attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens + token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0 + return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds def _merge_to_attention_mask(self, 
attention_mask: torch.Tensor, global_attention_mask: torch.Tensor): @@ -900,19 +901,6 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - # merge `global_attention_mask` and `attention_mask` - if global_attention_mask is not None: - attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) - - padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - pad_token_id=self.config.pad_token_id, - ) - if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -929,6 +917,19 @@ def forward( if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + # merge `global_attention_mask` and `attention_mask` + if global_attention_mask is not None: + attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) + + padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pad_token_id=self.config.pad_token_id, + ) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index f49f3ca2dfe9..59b55866543a 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -115,6 +115,18 @@ def prepare_config_and_inputs(self): def check_loss_output(self, result): self.parent.assertListEqual(list(result["loss"].size()), []) + def create_and_check_attention_mask_determinism( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = LongformerModel(config=config) + model.to(torch_device) + model.eval() + + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + output_with_mask = model(input_ids, attention_mask=attention_mask)[0] + output_without_mask = model(input_ids)[0] + self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4)) + def create_and_check_longformer_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -328,6 +340,10 @@ def test_longformer_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_longformer_model(*config_and_inputs) + def test_longformer_model_attention_mask_determinism(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs) + def test_longformer_model_global_attention_mask(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs) @@ -361,10 +377,13 @@ def test_inference_no_head(self): # 'Hello world!' 
input_ids = torch.tensor([[0, 20920, 232, 328, 1437, 2]], dtype=torch.long, device=torch_device) - output = model(input_ids)[0] + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + output = model(input_ids, attention_mask=attention_mask)[0] + output_without_mask = model(input_ids)[0] - expected_output_slice = torch.tensor([0.0105, 0.1807, -0.2062, -0.0643, 0.0629], device=torch_device) + expected_output_slice = torch.tensor([0.0549, 0.1087, -0.1119, -0.0368, 0.0250], device=torch_device) self.assertTrue(torch.allclose(output[0, 0, -5:], expected_output_slice, atol=1e-4)) + self.assertTrue(torch.allclose(output_without_mask[0, 0, -5:], expected_output_slice, atol=1e-4)) @slow def test_inference_no_head_long(self): @@ -377,9 +396,10 @@ def test_inference_no_head_long(self): ) # long input attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) - attention_mask[:, [1, 4, 21]] = 2 # Set global attention on a few random positions + global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device) + global_attention_mask[:, [1, 4, 21]] = 1 # Set global attention on a few random positions - output = model(input_ids, attention_mask=attention_mask)[0] + output = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)[0] expected_output_sum = torch.tensor(74585.8594, device=torch_device) expected_output_mean = torch.tensor(0.0243, device=torch_device) @@ -398,9 +418,9 @@ def test_inference_masked_lm_long(self): loss, prediction_scores = model(input_ids, labels=input_ids) - expected_loss = torch.tensor(0.0620, device=torch_device) - expected_prediction_scores_sum = torch.tensor(-6.1599e08, device=torch_device) - expected_prediction_scores_mean = torch.tensor(-3.0622, device=torch_device) + expected_loss = torch.tensor(0.0074, device=torch_device) + expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device) + 
expected_prediction_scores_mean = torch.tensor(-3.0348, device=torch_device) input_ids = input_ids.to(torch_device) self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4)) From a3cf0e7d15cbe7329a1519bc141797ab4cb43233 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 29 Jun 2020 17:28:05 +0200 Subject: [PATCH 12/19] save intermed --- src/transformers/modeling_longformer.py | 182 ++++++++++++------------ 1 file changed, 93 insertions(+), 89 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 7c07c4d6c54c..cfb437e4a097 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -289,34 +289,10 @@ def forward( attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) - # key tokens to be padded + # is index masked or global attention is_index_masked = attention_mask < 0 - - # all global attention tokens - is_index_global_attention = attention_mask > 0 - - # how many global attention tokens - num_global_attn_indices_per_batch = is_index_global_attention.long().sum(dim=1) - - # max global attention tokens of all batches - max_num_global_attn_indices_of_batches = num_global_attn_indices_per_batch.max() - - if max_num_global_attn_indices_of_batches > 0: - # To support the case of variable number of global attention in the rows of a batch, - # we use the following three selection masks to select global attention embeddings - # in a 3d tensor and pad it to `max_num_global_attn_indices_of_batches` - # 1) selecting embeddings that correspond to global attention - is_index_global_attention_nonzeros = is_index_global_attention.nonzero(as_tuple=True) - - # mask indicating which values are actually going to be padded for global attention computation - is_local_index_global_attention = torch.arange( - max_num_global_attn_indices_of_batches, device=attention_mask.device - ) < num_global_attn_indices_per_batch.unsqueeze(dim=-1) - - # 2) location of the 
non-padding values in the selected global attention - is_local_index_global_attention_indices = is_local_index_global_attention.nonzero(as_tuple=True) - # 3) location of the padding values in the selected global attention - local_index_no_global_attention_indices = (is_local_index_global_attention == 0).nonzero(as_tuple=True) + is_index_global_attn = attention_mask > 0 + is_global_attn = torch.all(is_index_global_attn).item() hidden_states = hidden_states.transpose(0, 1) @@ -336,8 +312,8 @@ def forward( query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - # local_attention_probs = (batch_size, seq_len, num_heads, window*2+1) - local_attention_probs = self._sliding_chunks_query_key_matmul( + # attention_probs = (batch_size, seq_len, num_heads, window*2+1) + attention_probs = self._sliding_chunks_query_key_matmul( query_vectors, key_vectors, self.one_sided_attention_window_size ) @@ -354,65 +330,55 @@ def forward( ) # pad local attention probs - local_attention_probs += diagonal_mask + attention_probs += diagonal_mask - assert list(local_attention_probs.size()) == [ + assert list(attention_probs.size()) == [ batch_size, seq_len, self.num_heads, self.one_sided_attention_window_size * 2 + 1, - ], f"local_attention_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attention_window_size * 2 + 1}), but is of size {local_attention_probs.size()}" + ], f"attention_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attention_window_size * 2 + 1}), but is of size {attention_probs.size()}" # compute local attention probs from global attention keys and contact over window dim - if max_num_global_attn_indices_of_batches > 0: - # create only global key vectors - key_vectors_only_global = key_vectors.new_zeros( - batch_size, max_num_global_attn_indices_of_batches, 
self.num_heads, self.head_dim - ) - key_vectors_only_global[is_local_index_global_attention_indices] = key_vectors[ - is_index_global_attention_nonzeros - ] - - # (batch_size, seq_len, num_heads, max_num_global_attn_indices_of_batches) - attention_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) - attention_probs_from_global_key[ - local_index_no_global_attention_indices[0], :, :, local_index_no_global_attention_indices[1] - ] = -10000.0 - - # concat to local_attention_probs + if is_global_attn: + # compute global attn indices required through out forward fn + max_num_global_attn_indices, is_index_global_attn_nonzero, is_local_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero = self._get_global_attn_indices(is_index_global_attn) + # calculate global attn probs from global key + global_key_attention_probs = self._concat_with_global_key_attn_probs(attention_probs, key_vectors, max_num_global_attn_indices, is_local_index_global_attn_nonzero, is_index_global_attn_nonzero) + # concat to attention_probs # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) - local_attention_probs = torch.cat((attention_probs_from_global_key, local_attention_probs), dim=-1) + attention_probs = torch.cat((global_key_attention_probs, attention_probs), dim=-1) # free memory - del key_vectors, query_vectors, key_vectors_only_global, attention_probs_from_global_key + del global_key_attention_probs - local_attention_probs_fp32 = F.softmax( - local_attention_probs, dim=-1, dtype=torch.float32 + attention_probs_fp32 = F.softmax( + attention_probs, dim=-1, dtype=torch.float32 ) # use fp32 for numerical stability - local_attention_probs = local_attention_probs_fp32.type_as(local_attention_probs) + attention_probs = attention_probs_fp32.type_as(attention_probs) # free memory - del local_attention_probs_fp32 + del attention_probs_fp32 # softmax sometimes inserts NaN if all positions are masked, replace them with 0 - 
local_attention_probs = torch.masked_fill( - local_attention_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0 + attention_probs = torch.masked_fill( + attention_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0 ) - local_attention_probs = F.dropout(local_attention_probs, p=self.dropout, training=self.training) + attention_probs = F.dropout(attention_probs, p=self.dropout, training=self.training) value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) # compute local attention output with global attention value and add - if max_num_global_attn_indices_of_batches > 0: - local_attn_probs_only_global = local_attention_probs.narrow(-1, 0, max_num_global_attn_indices_of_batches) + if is_global_attn: + local_attn_probs_only_global = attention_probs.narrow(-1, 0, max_num_global_attn_indices) - # only_global_value_vectors = value_vectors.new_zeros(batch_size, max_num_global_attn_indices_of_batches, self.num_heads, self.head_dim) + # only_global_value_vectors = value_vectors.new_zeros(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim) value_vectors_only_global = value_vectors.new_zeros( - batch_size, max_num_global_attn_indices_of_batches, self.num_heads, self.head_dim + batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim ) - value_vectors_only_global[is_local_index_global_attention_indices] = value_vectors[ - is_index_global_attention_nonzeros + value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[ + is_index_global_attn_nonzero ] # use `matmul` because `einsum` crashes sometimes with fp16 @@ -421,10 +387,10 @@ def forward( local_attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2) ).transpose(1, 2) - local_attn_probs_without_global = local_attention_probs.narrow( + local_attn_probs_without_global = attention_probs.narrow( -1, - max_num_global_attn_indices_of_batches, - local_attention_probs.size(-1) - 
max_num_global_attn_indices_of_batches, + max_num_global_attn_indices, + attention_probs.size(-1) - max_num_global_attn_indices, ).contiguous() # add computed attention output @@ -434,6 +400,7 @@ def forward( ) + local_attention_output_only_global ) +# attn_output_to_global_indices = self._compute_attn_output_to_global_indices(value_vectors, batch_size, max_num_global_attn_indices, is_local_index_global_attn_nonzero) # free memory del ( @@ -445,7 +412,7 @@ def forward( else: # compute local attention attention_output = self._sliding_chunks_matmul_attention_probs_value( - local_attention_probs, value_vectors, self.one_sided_attention_window_size + attention_probs, value_vectors, self.one_sided_attention_window_size ) assert attention_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" @@ -453,12 +420,12 @@ def forward( # compute value for global attention and overwrite to attention output # TODO: remove the redundant computation - if max_num_global_attn_indices_of_batches > 0: + if is_global_attn: only_global_attention_hidden_states = hidden_states.new_zeros( - max_num_global_attn_indices_of_batches, batch_size, embed_dim + max_num_global_attn_indices, batch_size, embed_dim ) - only_global_attention_hidden_states[is_local_index_global_attention_indices[::-1]] = hidden_states[ - is_index_global_attention_nonzeros[::-1] + only_global_attention_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[ + is_index_global_attn_nonzero[::-1] ] only_global_query_vectors = self.query_global(only_global_attention_hidden_states) @@ -470,9 +437,9 @@ def forward( only_global_query_vectors = ( only_global_query_vectors.contiguous() - .view(max_num_global_attn_indices_of_batches, batch_size * self.num_heads, self.head_dim) + .view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim) .transpose(0, 1) - ) # (batch_size * self.num_heads, max_num_global_attn_indices_of_batches, head_dim) + ) # (batch_size * 
self.num_heads, max_num_global_attn_indices, head_dim) global_key_vectors = ( global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) ) # batch_size * self.num_heads, seq_len, head_dim) @@ -482,16 +449,16 @@ def forward( global_attention_probs = torch.bmm(only_global_query_vectors, global_key_vectors.transpose(1, 2)) assert list(global_attention_probs.size()) == [ batch_size * self.num_heads, - max_num_global_attn_indices_of_batches, + max_num_global_attn_indices, seq_len, - ], f"global_attention_probs have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices_of_batches, seq_len)}, but is {global_attention_probs.size()}." + ], f"global_attention_probs have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attention_probs.size()}." global_attention_probs = global_attention_probs.view( - batch_size, self.num_heads, max_num_global_attn_indices_of_batches, seq_len + batch_size, self.num_heads, max_num_global_attn_indices, seq_len ) global_attention_probs[ - local_index_no_global_attention_indices[0], :, local_index_no_global_attention_indices[1], : + is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], : ] = -10000.0 global_attention_probs = global_attention_probs.masked_fill( @@ -499,7 +466,7 @@ def forward( ) global_attention_probs = global_attention_probs.view( - batch_size * self.num_heads, max_num_global_attn_indices_of_batches, seq_len + batch_size * self.num_heads, max_num_global_attn_indices, seq_len ) global_attention_probs_float = F.softmax( global_attention_probs, dim=-1, dtype=torch.float32 @@ -513,44 +480,81 @@ def forward( assert list(global_attention_output.size()) == [ batch_size * self.num_heads, - max_num_global_attn_indices_of_batches, + max_num_global_attn_indices, self.head_dim, - ], f"global_attention_output tensor has the wrong size. 
Size should be {(batch_size * self.num_heads, max_num_global_attn_indices_of_batches, self.head_dim)}, but is {global_attention_output.size()}." + ], f"global_attention_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attention_output.size()}." global_attention_output = global_attention_output.view( - batch_size, self.num_heads, max_num_global_attn_indices_of_batches, self.head_dim + batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim ) nonzero_global_attention_output = global_attention_output[ - is_local_index_global_attention_indices[0], :, is_local_index_global_attention_indices[1] + is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1] ] # overwrite values with global attention - attention_output[is_index_global_attention_nonzeros[::-1]] = nonzero_global_attention_output.view( - len(is_local_index_global_attention_indices[0]), -1 + attention_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attention_output.view( + len(is_local_index_global_attn_nonzero[0]), -1 ) attention_output = attention_output.transpose(0, 1) if output_attentions: - if max_num_global_attn_indices_of_batches > 0: + if is_global_attn: # With global attention, return global attention probabilities only # batch_size x num_heads x max_num_global_attention_tokens x sequence_length # which is the attention weights from tokens with global attention to all tokens # It doesn't not return local attention # In case of variable number of global attantion in the rows of a batch, # attention_probs are padded with -10000.0 attention scores - local_attention_probs = local_attention_probs.view( - batch_size, self.num_heads, max_num_global_attn_indices_of_batches, seq_len + attention_probs = attention_probs.view( + batch_size, self.num_heads, max_num_global_attn_indices, seq_len ) else: # without global attention, return local attention probabilities # batch_size 
x num_heads x sequence_length x window_size # which is the attention weights of every token attending to its neighbours - local_attention_probs = local_attention_probs.permute(0, 2, 1, 3) + attention_probs = attention_probs.permute(0, 2, 1, 3) - outputs = (attention_output, local_attention_probs) if output_attentions else (attention_output,) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) return outputs + @staticmethod + def _get_global_attn_indices(is_index_global_attn): + """ compute global attn indices required through out forward fn """ + # helper variable + num_global_attn_indices = is_index_global_attn.long().sum(dim=1) + + # max number of global attn indices in batch + max_num_global_attn_indices = num_global_attn_indices.max() + + # helper variable + is_local_index_global_attn = torch.arange(max_num_global_attn_indices, device=is_index_global_attn.device) < num_global_attn_indices.unsqueeze(dim=-1) + + # indices of global attn + is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True) + + # location of the non-padding values within global attention indices + is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True) + + # location of the padding values within global attention indices + is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True) + return max_num_global_attn_indices, is_index_global_attn_nonzero, is_local_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero + + def _concat_with_global_key_attn_probs(self, key_vectors, query_vectors, max_num_global_attn_indices, is_local_index_global_attention_indices, is_index_global_attn_nonzero): + # create only global key vectors + key_vectors_only_global = key_vectors.new_zeros( + batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim + ) + key_vectors_only_global[is_local_index_global_attention_indices] = key_vectors[ + is_index_global_attn_nonzero + 
] + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) + attention_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) + attention_probs_from_global_key[ + is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1] + ] = -10000.0 + return attention_probs_from_global_key + class LongformerAttention(nn.Module): def __init__(self, config, layer_id=0): From 977f3e27e32107d9ab26b336d1259f43141e18f8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 29 Jun 2020 18:11:02 +0200 Subject: [PATCH 13/19] refactor functions --- src/transformers/modeling_longformer.py | 416 +++++++++++++----------- 1 file changed, 230 insertions(+), 186 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index cfb437e4a097..c4a08cabccc7 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -114,7 +114,7 @@ def __init__(self, config, layer_id): attention_window > 0 ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}" - self.one_sided_attention_window_size = attention_window // 2 + self.one_sided_attn_window_size = attention_window // 2 @staticmethod def pad_and_transpose_last_two_dims(hidden_states_padded, padding): @@ -236,18 +236,18 @@ def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tenso self._mask_invalid_locations(diagonal_attention_scores, window_overlap) return diagonal_attention_scores - def _sliding_chunks_matmul_attention_probs_value( - self, attention_probs: torch.Tensor, value: torch.Tensor, window_overlap: int + def _sliding_chunks_matmul_attn_probs_value( + self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int ): - """Same as _sliding_chunks_query_key_matmul but for attention_probs and value tensors. 
It is expecting the same output + """Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. It is expecting the same output format from _sliding_chunks_query_key_matmul""" batch_size, seq_len, num_heads, head_dim = value.size() assert seq_len % (window_overlap * 2) == 0 - assert attention_probs.size()[:3] == value.size()[:3] - assert attention_probs.size(3) == 2 * window_overlap + 1 + assert attn_probs.size()[:3] == value.size()[:3] + assert attn_probs.size(3) == 2 * window_overlap + 1 chunks_count = seq_len // window_overlap - 1 # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap - chunked_attention_probs = attention_probs.transpose(1, 2).reshape( + chunked_attn_probs = attn_probs.transpose(1, 2).reshape( batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1 ) @@ -268,9 +268,9 @@ def _sliding_chunks_matmul_attention_probs_value( ) chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride) - chunked_attention_probs = self._pad_by_window_overlap_except_last_row(chunked_attention_probs) + chunked_attn_probs = self._pad_by_window_overlap_except_last_row(chunked_attn_probs) - context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attention_probs, chunked_value)) + context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value)) return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2) def forward( @@ -312,10 +312,8 @@ def forward( query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - # attention_probs = (batch_size, seq_len, num_heads, window*2+1) - attention_probs = self._sliding_chunks_query_key_matmul( - query_vectors, key_vectors, self.one_sided_attention_window_size - ) + # attn_probs = (batch_size, seq_len, num_heads, window*2+1) + 
attn_probs = self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, self.one_sided_attn_window_size) # from (batch_size x seq_len) to (batch_size x seq_len x num_heads x hidden_size) remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1) @@ -326,177 +324,99 @@ def forward( ) # diagonal mask with zeros everywhere and -inf inplace of padding diagonal_mask = self._sliding_chunks_query_key_matmul( - float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attention_window_size + float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size ) # pad local attention probs - attention_probs += diagonal_mask + attn_probs += diagonal_mask - assert list(attention_probs.size()) == [ + assert list(attn_probs.size()) == [ batch_size, seq_len, self.num_heads, - self.one_sided_attention_window_size * 2 + 1, - ], f"attention_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attention_window_size * 2 + 1}), but is of size {attention_probs.size()}" + self.one_sided_attn_window_size * 2 + 1, + ], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_probs.size()}" # compute local attention probs from global attention keys and contact over window dim if is_global_attn: # compute global attn indices required through out forward fn - max_num_global_attn_indices, is_index_global_attn_nonzero, is_local_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero = self._get_global_attn_indices(is_index_global_attn) + ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) = self._get_global_attn_indices(is_index_global_attn) # calculate global attn probs from global key - global_key_attention_probs = self._concat_with_global_key_attn_probs(attention_probs, key_vectors, 
max_num_global_attn_indices, is_local_index_global_attn_nonzero, is_index_global_attn_nonzero) - # concat to attention_probs + global_key_attn_probs = self._concat_with_global_key_attn_probs( + attn_probs=attn_probs, + key_vectors=key_vectors, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + ) + # concat to attn_probs # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) - attention_probs = torch.cat((global_key_attention_probs, attention_probs), dim=-1) + attn_probs = torch.cat((global_key_attn_probs, attn_probs), dim=-1) # free memory - del global_key_attention_probs + del global_key_attn_probs - attention_probs_fp32 = F.softmax( - attention_probs, dim=-1, dtype=torch.float32 - ) # use fp32 for numerical stability - attention_probs = attention_probs_fp32.type_as(attention_probs) + attn_probs_fp32 = F.softmax(attn_probs, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + attn_probs = attn_probs_fp32.type_as(attn_probs) # free memory - del attention_probs_fp32 + del attn_probs_fp32 # softmax sometimes inserts NaN if all positions are masked, replace them with 0 - attention_probs = torch.masked_fill( - attention_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0 - ) + attn_probs = torch.masked_fill(attn_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0) + + # apply dropout + attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training) - attention_probs = F.dropout(attention_probs, p=self.dropout, training=self.training) value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) # compute local attention output with global attention value and add if is_global_attn: - local_attn_probs_only_global = attention_probs.narrow(-1, 0, 
max_num_global_attn_indices) - - # only_global_value_vectors = value_vectors.new_zeros(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim) - - value_vectors_only_global = value_vectors.new_zeros( - batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim - ) - value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[ - is_index_global_attn_nonzero - ] - - # use `matmul` because `einsum` crashes sometimes with fp16 - # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) - local_attention_output_only_global = torch.matmul( - local_attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2) - ).transpose(1, 2) - - local_attn_probs_without_global = attention_probs.narrow( - -1, - max_num_global_attn_indices, - attention_probs.size(-1) - max_num_global_attn_indices, - ).contiguous() - - # add computed attention output - attention_output = ( - self._sliding_chunks_matmul_attention_probs_value( - local_attn_probs_without_global, value_vectors, self.one_sided_attention_window_size - ) - + local_attention_output_only_global - ) -# attn_output_to_global_indices = self._compute_attn_output_to_global_indices(value_vectors, batch_size, max_num_global_attn_indices, is_local_index_global_attn_nonzero) - - # free memory - del ( - local_attn_probs_only_global, - value_vectors_only_global, - local_attention_output_only_global, - local_attn_probs_without_global, + # compute sum of global and local attn + attn_output = self._compute_attn_output_with_global_indices( + value_vectors=value_vectors, + attn_probs=attn_probs, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, ) else: - # compute local attention - attention_output = self._sliding_chunks_matmul_attention_probs_value( - attention_probs, value_vectors, self.one_sided_attention_window_size + 
# compute local attn only + attn_output = self._sliding_chunks_matmul_attn_probs_value( + attn_probs, value_vectors, self.one_sided_attn_window_size ) - assert attention_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" - attention_output = attention_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() + assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" + attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() # compute value for global attention and overwrite to attention output # TODO: remove the redundant computation if is_global_attn: - only_global_attention_hidden_states = hidden_states.new_zeros( - max_num_global_attn_indices, batch_size, embed_dim - ) - only_global_attention_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[ - is_index_global_attn_nonzero[::-1] - ] - - only_global_query_vectors = self.query_global(only_global_attention_hidden_states) - global_key_vectors = self.key_global(hidden_states) - global_value_vectors = self.value_global(hidden_states) - - # normalize - only_global_query_vectors /= math.sqrt(self.head_dim) - - only_global_query_vectors = ( - only_global_query_vectors.contiguous() - .view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim) - .transpose(0, 1) - ) # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim) - global_key_vectors = ( - global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) - ) # batch_size * self.num_heads, seq_len, head_dim) - global_value_vectors = ( - global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) - ) # batch_size * self.num_heads, seq_len, head_dim) - global_attention_probs = torch.bmm(only_global_query_vectors, global_key_vectors.transpose(1, 2)) - assert list(global_attention_probs.size()) == [ - 
batch_size * self.num_heads, - max_num_global_attn_indices, - seq_len, - ], f"global_attention_probs have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attention_probs.size()}." - - global_attention_probs = global_attention_probs.view( - batch_size, self.num_heads, max_num_global_attn_indices, seq_len - ) - - global_attention_probs[ - is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], : - ] = -10000.0 - - global_attention_probs = global_attention_probs.masked_fill( - is_index_masked.unsqueeze(1).unsqueeze(2), -10000.0, - ) - - global_attention_probs = global_attention_probs.view( - batch_size * self.num_heads, max_num_global_attn_indices, seq_len + global_attn_output = self.compute_global_attn_output_from_hidden( + hidden_states=hidden_states, + max_num_global_attn_indices=max_num_global_attn_indices, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + is_index_masked=is_index_masked, ) - global_attention_probs_float = F.softmax( - global_attention_probs, dim=-1, dtype=torch.float32 - ) # use fp32 for numerical stability - global_attention_probs = F.dropout( - global_attention_probs_float.type_as(global_attention_probs), p=self.dropout, training=self.training - ) - - global_attention_output = torch.bmm(global_attention_probs, global_value_vectors) - - assert list(global_attention_output.size()) == [ - batch_size * self.num_heads, - max_num_global_attn_indices, - self.head_dim, - ], f"global_attention_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attention_output.size()}." 
- - global_attention_output = global_attention_output.view( - batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim - ) - nonzero_global_attention_output = global_attention_output[ + # get only non zero global attn output + nonzero_global_attn_output = global_attn_output[ is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1] ] - # overwrite values with global attention - attention_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attention_output.view( + attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view( len(is_local_index_global_attn_nonzero[0]), -1 ) - attention_output = attention_output.transpose(0, 1) + attn_output = attn_output.transpose(0, 1) if output_attentions: if is_global_attn: @@ -505,55 +425,179 @@ def forward( # which is the attention weights from tokens with global attention to all tokens # It doesn't not return local attention # In case of variable number of global attantion in the rows of a batch, - # attention_probs are padded with -10000.0 attention scores - attention_probs = attention_probs.view( - batch_size, self.num_heads, max_num_global_attn_indices, seq_len - ) + # attn_probs are padded with -10000.0 attention scores + attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) else: # without global attention, return local attention probabilities # batch_size x num_heads x sequence_length x window_size # which is the attention weights of every token attending to its neighbours - attention_probs = attention_probs.permute(0, 2, 1, 3) + attn_probs = attn_probs.permute(0, 2, 1, 3) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + outputs = (attn_output, attn_probs) if output_attentions else (attn_output,) return outputs - @staticmethod - def _get_global_attn_indices(is_index_global_attn): - """ compute global attn indices required through out forward fn """ - # helper variable - 
num_global_attn_indices = is_index_global_attn.long().sum(dim=1) + @staticmethod + def _get_global_attn_indices(is_index_global_attn): + """ compute global attn indices required through out forward fn """ + # helper variable + num_global_attn_indices = is_index_global_attn.long().sum(dim=1) + + # max number of global attn indices in batch + max_num_global_attn_indices = num_global_attn_indices.max() + + # helper variable + is_local_index_global_attn = torch.arange( + max_num_global_attn_indices, device=is_index_global_attn.device + ) < num_global_attn_indices.unsqueeze(dim=-1) + + # indices of global attn + is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True) + + # location of the non-padding values within global attention indices + is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True) + + # location of the padding values within global attention indices + is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True) + return ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) - # max number of global attn indices in batch - max_num_global_attn_indices = num_global_attn_indices.max() + def _concat_with_global_key_attn_probs( + self, + key_vectors, + query_vectors, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attention_indices, + is_local_index_no_global_attn_nonzero, + ): + batch_size = key_vectors.shape[0] - # helper variable - is_local_index_global_attn = torch.arange(max_num_global_attn_indices, device=is_index_global_attn.device) < num_global_attn_indices.unsqueeze(dim=-1) + # create only global key vectors + key_vectors_only_global = key_vectors.new_zeros( + batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim + ) + key_vectors_only_global[is_local_index_global_attention_indices] = 
key_vectors[is_index_global_attn_nonzero] + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) + attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) + attn_probs_from_global_key[ + is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1] + ] = -10000.0 + return attn_probs_from_global_key + + def _compute_attn_output_with_global_indices( + self, + value_vectors, + attn_probs, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + ): + batch_size = attn_probs.shape[0] - # indices of global attn - is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True) + # cut local attn probs to global only + attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices) + # get value vectors for global only + value_vectors_only_global = value_vectors.new_zeros( + batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim + ) + value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero] + + # use `matmul` because `einsum` crashes sometimes with fp16 + # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) + # compute attn output only global + attn_output_only_global = torch.matmul( + attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2) + ).transpose(1, 2) + + # compute attn output with global + attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( + attn_probs_only_global, value_vectors, self.one_sided_attn_window_size + ) + return attn_output_only_global + attn_output_without_global - # location of the non-padding values within global attention indices - is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True) + def compute_global_attn_output_from_hidden( + self, + hidden_states, + max_num_global_attn_indices, + is_local_index_global_attn_nonzero, + 
is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + is_index_masked, + ): + batch_size, seq_len = hidden_states.shape[:2] - # location of the padding values within global attention indices - is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True) - return max_num_global_attn_indices, is_index_global_attn_nonzero, is_local_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero + # prepare global hidden states + global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim) + global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[ + is_index_global_attn_nonzero[::-1] + ] - def _concat_with_global_key_attn_probs(self, key_vectors, query_vectors, max_num_global_attn_indices, is_local_index_global_attention_indices, is_index_global_attn_nonzero): - # create only global key vectors - key_vectors_only_global = key_vectors.new_zeros( - batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim - ) - key_vectors_only_global[is_local_index_global_attention_indices] = key_vectors[ - is_index_global_attn_nonzero - ] - # (batch_size, seq_len, num_heads, max_num_global_attn_indices) - attention_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) - attention_probs_from_global_key[ - is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1] - ] = -10000.0 - return attention_probs_from_global_key + # global key, query, value + global_query_vectors_only_global = self.query_global(global_attn_hidden_states) + global_key_vectors = self.key_global(hidden_states) + global_value_vectors = self.value_global(hidden_states) + + # normalize + global_query_vectors_only_global /= math.sqrt(self.head_dim) + + # reshape + global_query_vectors_only_global = ( + global_query_vectors_only_global.contiguous() + .view(max_num_global_attn_indices, batch_size * 
self.num_heads, self.head_dim) + .transpose(0, 1) + ) # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim) + global_key_vectors = ( + global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + ) # batch_size * self.num_heads, seq_len, head_dim) + global_value_vectors = ( + global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + ) # batch_size * self.num_heads, seq_len, head_dim) + + # compute attn scores + global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2)) + + assert list(global_attn_scores.size()) == [ + batch_size * self.num_heads, + max_num_global_attn_indices, + seq_len, + ], f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attn_scores.size()}." + + global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + + global_attn_scores[ + is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], : + ] = -10000.0 + + global_attn_scores = global_attn_scores.masked_fill(is_index_masked.unsqueeze(1).unsqueeze(2), -10000.0,) + + global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + + # compute global attn probs + global_attn_probs_float = F.softmax( + global_attn_scores, dim=-1, dtype=torch.float32 + ) # use fp32 for numerical stability + + global_attn_probs = F.dropout( + global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training + ) + + # global attn output + global_attn_output = torch.bmm(global_attn_probs, global_value_vectors) + + assert list(global_attn_output.size()) == [ + batch_size * self.num_heads, + max_num_global_attn_indices, + self.head_dim, + ], f"global_attn_output tensor has the wrong size. 
Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}." + + global_attn_output = global_attn_output.view( + batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim + ) + return global_attn_output class LongformerAttention(nn.Module): @@ -585,8 +629,8 @@ def forward( self, hidden_states, attention_mask=None, output_attentions=False, ): self_outputs = self.self(hidden_states, attention_mask, output_attentions,) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + attn_output = self.output(self_outputs[0], hidden_states) + outputs = (attn_output,) + self_outputs[1:] # add attentions if we output them return outputs @@ -600,12 +644,12 @@ def __init__(self, config, layer_id=0): def forward( self, hidden_states, attention_mask=None, output_attentions=False, ): - self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,) - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + self_attn_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,) + attn_output = self_attn_outputs[0] + outputs = self_attn_outputs[1:] # add self attentions if we output attention weights - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) + intermediate_output = self.intermediate(attn_output) + layer_output = self.output(intermediate_output, attn_output) outputs = (layer_output,) + outputs return outputs From 961ffbb5aa643110f5dfe98df20c508c7b33dd7d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 29 Jun 2020 17:06:08 +0000 Subject: [PATCH 14/19] finish function refactor --- src/transformers/modeling_longformer.py | 43 ++++++++++++++----------- 1 file 
changed, 25 insertions(+), 18 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index c4a08cabccc7..5562701cd10d 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -292,7 +292,7 @@ def forward( # is index masked or global attention is_index_masked = attention_mask < 0 is_index_global_attn = attention_mask > 0 - is_global_attn = torch.all(is_index_global_attn).item() + is_global_attn = not torch.all(~is_index_global_attn).item() hidden_states = hidden_states.transpose(0, 1) @@ -313,7 +313,9 @@ def forward( key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) # attn_probs = (batch_size, seq_len, num_heads, window*2+1) - attn_probs = self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, self.one_sided_attn_window_size) + attn_scores = self._sliding_chunks_query_key_matmul( + query_vectors, key_vectors, self.one_sided_attn_window_size + ) # from (batch_size x seq_len) to (batch_size x seq_len x num_heads x hidden_size) remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1) @@ -328,14 +330,14 @@ def forward( ) # pad local attention probs - attn_probs += diagonal_mask + attn_scores += diagonal_mask - assert list(attn_probs.size()) == [ + assert list(attn_scores.size()) == [ batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1, - ], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_probs.size()}" + ], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}" # compute local attention probs from global attention keys and contact over window dim if is_global_attn: @@ -347,8 +349,8 @@ def forward( is_local_index_no_global_attn_nonzero, ) = 
self._get_global_attn_indices(is_index_global_attn) # calculate global attn probs from global key - global_key_attn_probs = self._concat_with_global_key_attn_probs( - attn_probs=attn_probs, + global_key_attn_scores = self._concat_with_global_key_attn_probs( + query_vectors=query_vectors, key_vectors=key_vectors, max_num_global_attn_indices=max_num_global_attn_indices, is_index_global_attn_nonzero=is_index_global_attn_nonzero, @@ -357,13 +359,13 @@ def forward( ) # concat to attn_probs # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) - attn_probs = torch.cat((global_key_attn_probs, attn_probs), dim=-1) + attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1) # free memory - del global_key_attn_probs + del global_key_attn_scores - attn_probs_fp32 = F.softmax(attn_probs, dim=-1, dtype=torch.float32) # use fp32 for numerical stability - attn_probs = attn_probs_fp32.type_as(attn_probs) + attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + attn_probs = attn_probs_fp32.type_as(attn_scores) # free memory del attn_probs_fp32 @@ -445,14 +447,14 @@ def _get_global_attn_indices(is_index_global_attn): # max number of global attn indices in batch max_num_global_attn_indices = num_global_attn_indices.max() + # indices of global attn + is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True) + # helper variable is_local_index_global_attn = torch.arange( max_num_global_attn_indices, device=is_index_global_attn.device ) < num_global_attn_indices.unsqueeze(dim=-1) - # indices of global attn - is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True) - # location of the non-padding values within global attention indices is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True) @@ -471,7 +473,7 @@ def _concat_with_global_key_attn_probs( query_vectors, max_num_global_attn_indices, is_index_global_attn_nonzero, - 
is_local_index_global_attention_indices, + is_local_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero, ): batch_size = key_vectors.shape[0] @@ -480,7 +482,7 @@ def _concat_with_global_key_attn_probs( key_vectors_only_global = key_vectors.new_zeros( batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim ) - key_vectors_only_global[is_local_index_global_attention_indices] = key_vectors[is_index_global_attn_nonzero] + key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero] # (batch_size, seq_len, num_heads, max_num_global_attn_indices) attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) attn_probs_from_global_key[ @@ -513,9 +515,14 @@ def _compute_attn_output_with_global_indices( attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2) ).transpose(1, 2) + # reshape attn probs + attn_probs_without_global = attn_probs.narrow( + -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices + ).contiguous() + # compute attn output with global attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( - attn_probs_only_global, value_vectors, self.one_sided_attn_window_size + attn_probs_without_global, value_vectors, self.one_sided_attn_window_size ) return attn_output_only_global + attn_output_without_global @@ -528,7 +535,7 @@ def compute_global_attn_output_from_hidden( is_local_index_no_global_attn_nonzero, is_index_masked, ): - batch_size, seq_len = hidden_states.shape[:2] + seq_len, batch_size = hidden_states.shape[:2] # prepare global hidden states global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim) From 3d4952a8af4e18f5838f27d8d008cd9a91b474e0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 29 Jun 2020 17:29:11 +0000 Subject: [PATCH 15/19] fix tests --- src/transformers/modeling_longformer.py | 3 +++ 
tests/test_modeling_longformer.py | 12 ++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 5562701cd10d..bf0eb47a98fc 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -1552,6 +1552,9 @@ def forward( pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) + print(f"logits: {logits.shape}") + print(f"pooled_output: {pooled_output.shape}") + print(f"num_choices: {num_choices}") reshaped_logits = logits.view(-1, num_choices) outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index 59b55866543a..1819e1aa0605 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -285,7 +285,8 @@ def prepare_config_and_inputs_for_common(self): token_labels, choice_labels, ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + global_attention_mask = torch.zeros_like(input_ids) + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask, "global_attention_mask": global_attention_mask} return config, inputs_dict def prepare_config_and_inputs_for_question_answering(self): @@ -319,11 +320,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase): ( LongformerModel, LongformerForMaskedLM, - # TODO: make tests pass for those models - # LongformerForSequenceClassification, - # LongformerForQuestionAnswering, - # LongformerForTokenClassification, - # LongformerForMultipleChoice, + LongformerForSequenceClassification, + LongformerForQuestionAnswering, + LongformerForTokenClassification, + LongformerForMultipleChoice, ) if is_torch_available() else () From 390bf01b4556a06dbb0b9e6ae9720ecc21fcb080 Mon Sep 17 00:00:00 2001 
From: Patrick von Platen Date: Mon, 29 Jun 2020 17:40:21 +0000 Subject: [PATCH 16/19] fix all tests but one --- src/transformers/modeling_longformer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index bf0eb47a98fc..5562701cd10d 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -1552,9 +1552,6 @@ def forward( pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) - print(f"logits: {logits.shape}") - print(f"pooled_output: {pooled_output.shape}") - print(f"num_choices: {num_choices}") reshaped_logits = logits.view(-1, num_choices) outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here From 149b057ce537f73884ea74316ef2282bfa8f3b93 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 30 Jun 2020 08:31:21 +0000 Subject: [PATCH 17/19] finish longformer --- tests/test_modeling_common.py | 2 +- tests/test_modeling_longformer.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4df1f75011ae..5e6705eaccaf 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -812,7 +812,7 @@ def test_multigpu_data_parallel_forward(self): # Wrap model in nn.DataParallel model = torch.nn.DataParallel(model) with torch.no_grad(): - _ = model(**inputs_dict) + _ = model(**self._prepare_for_class(inputs_dict, model_class)) global_rng = random.Random() diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index 1819e1aa0605..fed22060501c 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -286,7 +286,12 @@ def prepare_config_and_inputs_for_common(self): choice_labels, ) = config_and_inputs global_attention_mask = torch.zeros_like(input_ids) - inputs_dict = {"input_ids": input_ids, "token_type_ids": 
token_type_ids, "attention_mask": input_mask, "global_attention_mask": global_attention_mask} + inputs_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_mask": input_mask, + "global_attention_mask": global_attention_mask, + } return config, inputs_dict def prepare_config_and_inputs_for_question_answering(self): From d0e9f0fbb67afd640f7a01a104b5933dfad822ca Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Jul 2020 17:20:25 +0200 Subject: [PATCH 18/19] address sams and izs comments --- + | 1562 +++++++++++++++++++++++ src/transformers/modeling_longformer.py | 58 +- 2 files changed, 1592 insertions(+), 28 deletions(-) create mode 100644 + diff --git a/+ b/+ new file mode 100644 index 000000000000..0f133eb7db14 --- /dev/null +++ b/+ @@ -0,0 +1,1562 @@ +# coding=utf-8 +# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Longformer model. 
""" + +import logging +import math +import warnings + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import functional as F + +from .configuration_longformer import LongformerConfig +from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable +from .modeling_bert import BertIntermediate, BertLayerNorm, BertOutput, BertPooler, BertPreTrainedModel, BertSelfOutput +from .modeling_roberta import RobertaEmbeddings, RobertaLMHead +from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer + + +logger = logging.getLogger(__name__) + +_TOKENIZER_FOR_DOC = "LongformerTokenizer" + +LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "allenai/longformer-base-4096", + "allenai/longformer-large-4096", + "allenai/longformer-large-4096-finetuned-triviaqa", + "allenai/longformer-base-4096-extra.pos.embd.only", + "allenai/longformer-large-4096-extra.pos.embd.only", + # See all Longformer models at https://huggingface.co/models?filter=longformer +] + + +def _get_question_end_index(input_ids, sep_token_id): + """ + Computes the index of the first occurance of `sep_token_id`. + """ + + sep_token_indices = (input_ids == sep_token_id).nonzero() + batch_size = input_ids.shape[0] + + assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions" + assert ( + sep_token_indices.shape[0] == 3 * batch_size + ), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error." + + return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1] + + +def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True): + """ + Computes global attention mask by putting attention on all tokens + before `sep_token_id` if `before_sep_token is True` else after + `sep_token_id`. 
+ """ + + question_end_index = _get_question_end_index(input_ids, sep_token_id) + question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1 + # bool attention mask with True in locations of global attention + attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device) + if before_sep_token is True: + attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.uint8) + else: + # last token is separation token and should not be counted and in the middle are two separation tokens + attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.uint8) * ( + attention_mask.expand_as(input_ids) < input_ids.shape[-1] + ).to(torch.uint8) + + return attention_mask + + +class LongformerSelfAttention(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.num_heads = config.num_attention_heads + self.head_dim = int(config.hidden_size / config.num_attention_heads) + self.embed_dim = config.hidden_size + + self.query = nn.Linear(config.hidden_size, self.embed_dim) + self.key = nn.Linear(config.hidden_size, self.embed_dim) + self.value = nn.Linear(config.hidden_size, self.embed_dim) + + # separate projection layers for tokens with global attention + self.query_global = nn.Linear(config.hidden_size, self.embed_dim) + self.key_global = nn.Linear(config.hidden_size, self.embed_dim) + self.value_global = nn.Linear(config.hidden_size, self.embed_dim) + + self.dropout = config.attention_probs_dropout_prob + + self.layer_id = layer_id + attention_window = config.attention_window[self.layer_id] + assert ( + attention_window % 2 == 0 + ), f"`attention_window` for layer {self.layer_id} has to be an even value. 
Given {attention_window}" + assert ( + attention_window > 0 + ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}" + + self.one_sided_attn_window_size = attention_window // 2 + + @staticmethod + def _pad_and_transpose_last_two_dims(hidden_states_padded, padding): + """pads rows and then flips rows and columns""" + hidden_states_padded = F.pad( + hidden_states_padded, padding + ) # padding value is not important because it will be overwritten + hidden_states_padded = hidden_states_padded.transpose(-1, -2) + return hidden_states_padded + + @staticmethod + def _pad_by_window_overlap_except_last_row(chunked_hidden_states): + """shift every row 1 step right, converting columns into diagonals""" + total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size() + chunked_hidden_states = F.pad( + chunked_hidden_states, (0, window_overlap + 1) + ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten + chunked_hidden_states = chunked_hidden_states.view(total_num_heads, num_chunks, -1) # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap+window_overlap + chunked_hidden_states = chunked_hidden_states[:, :, :-window_overlap] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap + chunked_hidden_states = chunked_hidden_states.view( + total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim + ) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap + chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] + return chunked_hidden_states + + @staticmethod + def _chunk(hidden_states, window_overlap): + """convert into overlapping chunkings. 
Chunk size = 2w, overlap size = w""" + + # non-overlapping chunks of size = 2w + hidden_states = hidden_states.view( + hidden_states.size(0), + hidden_states.size(1) // (window_overlap * 2), + window_overlap * 2, + hidden_states.size(2), + ) + + # use `as_strided` to make the chunks overlap with an overlap size = window_overlap + chunk_size = list(hidden_states.size()) + chunk_size[1] = chunk_size[1] * 2 - 1 + + chunk_stride = list(hidden_states.stride()) + chunk_stride[1] = chunk_stride[1] // 2 + return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) + + def _mask_invalid_locations(self, input_tensor, affected_seq_len) -> torch.Tensor: + beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0]) + beginning_mask = beginning_mask_2d[None, :, None, :] + ending_mask = beginning_mask.flip(dims=(1, 3)) + beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] + beginning_mask = beginning_mask.expand(beginning_input.size()) + beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 + ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] + ending_mask = ending_mask.expand(ending_input.size()) + ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 + + def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int): + """Matrix multiplication of query and key tensors using with a sliding window attention pattern. + This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) + with an overlap of size window_overlap""" + batch_size, seq_len, num_heads, head_dim = query.size() + assert ( + seq_len % (window_overlap * 2) == 0 + ), f"Sequence length should be multiple of {window_overlap * 2}. 
Given {seq_len}" + assert query.size() == key.size() + + chunks_count = seq_len // window_overlap - 1 + + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 + query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) + key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) + + chunked_query = self._chunk(query, window_overlap) + chunked_key = self._chunk(key, window_overlap) + + # matrix multipication + # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcxy: batch_size * num_heads x chunks x 2window_overlap x window_overlap + chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunked_query, chunked_key)) # multiply + + # convert diagonals into columns + diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims( + chunked_attention_scores, padding=(0, 0, 0, 1) + ) + + # allocate space for the overall attention matrix where the chunks are combined. The last dimension + # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to + # window_overlap previous words). The following column is attention score from each word to itself, then + # followed by window_overlap columns for the upper triangle. 
+ + diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty( + (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1) + ) + + # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions + # - copying the main diagonal and the upper triangle + diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[ + :, :, :window_overlap, : window_overlap + 1 + ] + diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[ + :, -1, window_overlap:, : window_overlap + 1 + ] + # - copying the lower triangle + diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[ + :, :, -(window_overlap + 1) : -1, window_overlap + 1 : + ] + diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[ + :, 0, : window_overlap - 1, 1 - window_overlap : + ] + + # separate batch_size and num_heads dimensions again + diagonal_attention_scores = diagonal_attention_scores.view( + batch_size, num_heads, seq_len, 2 * window_overlap + 1 + ).transpose(2, 1) + + self._mask_invalid_locations(diagonal_attention_scores, window_overlap) + return diagonal_attention_scores + + def _sliding_chunks_matmul_attn_probs_value( + self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int + ): + """Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. 
+ Returned tensor will be of the same shape as `attn_probs`""" + batch_size, seq_len, num_heads, head_dim = value.size() + assert seq_len % (window_overlap * 2) == 0 + assert attn_probs.size()[:3] == value.size()[:3] + assert attn_probs.size(3) == 2 * window_overlap + 1 + chunks_count = seq_len // window_overlap - 1 + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap + chunked_attn_probs = attn_probs.transpose(1, 2).reshape( + batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1 + ) + + # group batch_size and num_heads dimensions into one + value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) + + # pad seq_len with w at the beginning of the sequence and another window overlap at the end + padded_value = F.pad(value, (0, 0, window_overlap, window_overlap), value=-1) + + # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap + chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim) + chunked_value_stride = padded_value.stride() + chunked_value_stride = ( + chunked_value_stride[0], + window_overlap * chunked_value_stride[1], + chunked_value_stride[1], + chunked_value_stride[2], + ) + chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride) + + chunked_attn_probs = self._pad_by_window_overlap_except_last_row(chunked_attn_probs) + + context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value)) + return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2) + + def forward( + self, hidden_states, attention_mask=None, output_attentions=False, + ): + """ + LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`. + Padding to `attention_window` happens in LongformerModel.forward to avoid redoing the padding on each layer. 
+ + The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to + -ve: no attention + 0: local attention + +ve: global attention + + """ + + attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) + + # is index masked or global attention + is_index_masked = attention_mask < 0 + is_index_global_attn = attention_mask > 0 + is_global_attn = any(is_index_global_attn).item() + + hidden_states = hidden_states.transpose(0, 1) + + # project hidden states + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = self.value(hidden_states) + + seq_len, batch_size, embed_dim = hidden_states.size() + assert ( + embed_dim == self.embed_dim + ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}" + + # normalize query + query_vectors /= math.sqrt(self.head_dim) + + query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + + # attn_probs = (batch_size, seq_len, num_heads, window*2+1) + attn_scores = self._sliding_chunks_query_key_matmul( + query_vectors, key_vectors, self.one_sided_attn_window_size + ) + + # values to pad for attention probs + remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1) + + # cast to fp32/fp16 then replace 1's with -inf + float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill( + remove_from_windowed_attention_mask, -10000.0 + ) + # diagonal mask with zeros everywhere and -inf inplace of padding + diagonal_mask = self._sliding_chunks_query_key_matmul( + float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size + ) + + # pad local attention probs + attn_scores += diagonal_mask + + assert list(attn_scores.size()) == [ + batch_size, + seq_len, + self.num_heads, + self.one_sided_attn_window_size * 2 + 1, + ], f"attn_probs 
should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}" + + # compute local attention probs from global attention keys and contact over window dim + if is_global_attn: + # compute global attn indices required through out forward fn + ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) = self._get_global_attn_indices(is_index_global_attn) + # calculate global attn probs from global key + global_key_attn_scores = self._concat_with_global_key_attn_probs( + query_vectors=query_vectors, + key_vectors=key_vectors, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + ) + # concat to attn_probs + # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) + attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1) + + # free memory + del global_key_attn_scores + + attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + attn_probs = attn_probs_fp32.type_as(attn_scores) + + # free memory + del attn_probs_fp32 + + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + attn_probs = torch.masked_fill(attn_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0) + + # apply dropout + attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training) + + value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + + # compute local attention output with global attention value and add + if is_global_attn: + # compute sum of global and local attn + attn_output = self._compute_attn_output_with_global_indices( + 
value_vectors=value_vectors, + attn_probs=attn_probs, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + ) + else: + # compute local attn only + attn_output = self._sliding_chunks_matmul_attn_probs_value( + attn_probs, value_vectors, self.one_sided_attn_window_size + ) + + assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" + attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() + + # compute value for global attention and overwrite to attention output + # TODO: remove the redundant computation + if is_global_attn: + global_attn_output = self.compute_global_attn_output_from_hidden( + hidden_states=hidden_states, + max_num_global_attn_indices=max_num_global_attn_indices, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + is_index_masked=is_index_masked, + ) + + # get only non zero global attn output + nonzero_global_attn_output = global_attn_output[ + is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1] + ] + # overwrite values with global attention + attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view( + len(is_local_index_global_attn_nonzero[0]), -1 + ) + + attn_output = attn_output.transpose(0, 1) + + if output_attentions: + if is_global_attn: + # With global attention, return global attention probabilities only + # batch_size x num_heads x max_num_global_attention_tokens x sequence_length + # which is the attention weights from tokens with global attention to all tokens + # It doesn't not return local attention + # In case of variable number of global attantion in the rows of a batch, + # attn_probs are padded with 
-10000.0 attention scores + attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + else: + # without global attention, return local attention probabilities + # batch_size x num_heads x sequence_length x window_size + # which is the attention weights of every token attending to its neighbours + attn_probs = attn_probs.permute(0, 2, 1, 3) + + outputs = (attn_output, attn_probs) if output_attentions else (attn_output,) + return outputs + + @staticmethod + def _get_global_attn_indices(is_index_global_attn): + """ compute global attn indices required throughout forward pass """ + # helper variable + num_global_attn_indices = is_index_global_attn.long().sum(dim=1) + + # max number of global attn indices in batch + max_num_global_attn_indices = num_global_attn_indices.max() + + # indices of global attn + is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True) + + # helper variable + is_local_index_global_attn = torch.arange( + max_num_global_attn_indices, device=is_index_global_attn.device + ) < num_global_attn_indices.unsqueeze(dim=-1) + + # location of the non-padding values within global attention indices + is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True) + + # location of the padding values within global attention indices + is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True) + return ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) + + def _concat_with_global_key_attn_probs( + self, + key_vectors, + query_vectors, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ): + batch_size = key_vectors.shape[0] + + # create only global key vectors + key_vectors_only_global = key_vectors.new_zeros( + batch_size, max_num_global_attn_indices, 
self.num_heads, self.head_dim + ) + key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero] + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) + attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) + attn_probs_from_global_key[ + is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1] + ] = -10000.0 + return attn_probs_from_global_key + + def _compute_attn_output_with_global_indices( + self, + value_vectors, + attn_probs, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + ): + batch_size = attn_probs.shape[0] + + # cut local attn probs to global only + attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices) + # get value vectors for global only + value_vectors_only_global = value_vectors.new_zeros( + batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim + ) + value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero] + + # use `matmul` because `einsum` crashes sometimes with fp16 + # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) + # compute attn output only global + attn_output_only_global = torch.matmul( + attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2) + ).transpose(1, 2) + + # reshape attn probs + attn_probs_without_global = attn_probs.narrow( + -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices + ).contiguous() + + # compute attn output with global + attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( + attn_probs_without_global, value_vectors, self.one_sided_attn_window_size + ) + return attn_output_only_global + attn_output_without_global + + def compute_global_attn_output_from_hidden( + self, + hidden_states, + max_num_global_attn_indices, + 
is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + is_index_masked, + ): + seq_len, batch_size = hidden_states.shape[:2] + + # prepare global hidden states + global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim) + global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[ + is_index_global_attn_nonzero[::-1] + ] + + # global key, query, value + global_query_vectors_only_global = self.query_global(global_attn_hidden_states) + global_key_vectors = self.key_global(hidden_states) + global_value_vectors = self.value_global(hidden_states) + + # normalize + global_query_vectors_only_global /= math.sqrt(self.head_dim) + + # reshape + global_query_vectors_only_global = ( + global_query_vectors_only_global.contiguous() + .view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim) + global_key_vectors = ( + global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + ) # batch_size * self.num_heads, seq_len, head_dim) + global_value_vectors = ( + global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + ) # batch_size * self.num_heads, seq_len, head_dim) + + # compute attn scores + global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2)) + + assert list(global_attn_scores.size()) == [ + batch_size * self.num_heads, + max_num_global_attn_indices, + seq_len, + ], f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attn_scores.size()}." 
+ + global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + + global_attn_scores[ + is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], : + ] = -10000.0 + + global_attn_scores = global_attn_scores.masked_fill(is_index_masked.unsqueeze(1).unsqueeze(2), -10000.0,) + + global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + + # compute global attn probs + global_attn_probs_float = F.softmax( + global_attn_scores, dim=-1, dtype=torch.float32 + ) # use fp32 for numerical stability + + global_attn_probs = F.dropout( + global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training + ) + + # global attn output + global_attn_output = torch.bmm(global_attn_probs, global_value_vectors) + + assert list(global_attn_output.size()) == [ + batch_size * self.num_heads, + max_num_global_attn_indices, + self.head_dim, + ], f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}." 
+ + global_attn_output = global_attn_output.view( + batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim + ) + return global_attn_output + + +class LongformerAttention(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.self = LongformerSelfAttention(config, layer_id) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, hidden_states, attention_mask=None, output_attentions=False, + ): + self_outputs = self.self(hidden_states, attention_mask, output_attentions,) + attn_output = self.output(self_outputs[0], hidden_states) + outputs = (attn_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class LongformerLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.attention = LongformerAttention(config, layer_id) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, hidden_states, attention_mask=None, output_attentions=False, + ): + self_attn_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,) + attn_output = self_attn_outputs[0] + outputs = self_attn_outputs[1:] # add self 
attentions if we output attention weights + + intermediate_output = self.intermediate(attn_output) + layer_output = self.output(intermediate_output, attn_output) + outputs = (layer_output,) + outputs + return outputs + + +class LongformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)]) + + def forward( + self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False, + ): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), hidden_states, attention_mask, + ) + else: + layer_outputs = layer_module(hidden_states, attention_mask, output_attentions,) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + +class LongformerPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = LongformerConfig + base_model_prefix = "longformer" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +LONGFORMER_START_DOCSTRING = r""" + + This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~transformers.LongformerConfig`): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +LONGFORMER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`transformers.LongformerTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.__call__` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + + `What are attention masks? 
<../glossary.html#attention-mask>`__ + + global_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + Mask to decide the attention given on each token, local attention or global attention. + Tokens with global attention attend to all other tokens, and all other tokens attend to them. This is important for + task-specific finetuning because it makes the model more flexible at representing the task. For example, + for classification, the <s> token should be given global attention. For QA, all question tokens should also have + global attention. Please refer to the `Longformer paper <https://arxiv.org/abs/2004.05150>`__ for more details. + Mask values selected in ``[0, 1]``: + ``0`` for local attention (a sliding window attention), + ``1`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + Segment token indices to indicate first and second portions of the inputs. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + corresponds to a `sentence B` token + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. 
+ output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): + If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. +""" + + +@add_start_docstrings( + "The bare Longformer Model outputting raw hidden-states without any specific head on top.", + LONGFORMER_START_DOCSTRING, +) +class LongformerModel(LongformerPreTrainedModel): + """ + This class overrides :class:`~transformers.RobertaModel` to provide the ability to process + long sequences following the selfattention approach described in `Longformer: the Long-Document Transformer + `__ by Iz Beltagy, Matthew E. Peters, and Arman Cohan. Longformer selfattention + combines a local (sliding window) and global attention to extend to long documents without the O(n^2) increase in + memory and compute. + + The selfattention module `LongformerSelfAttention` implemented here supports the combination of local and + global attention but it lacks support for autoregressive attention and dilated attention. Autoregressive + and dilated attention are more relevant for autoregressive language modeling than finetuning on downstream + tasks. Future release will add support for autoregressive attention, but the support for dilated attention + requires a custom CUDA kernel to be memory and compute efficient. 
+ + """ + + config_class = LongformerConfig + base_model_prefix = "longformer" + + def __init__(self, config): + super().__init__(config) + self.config = config + + if isinstance(config.attention_window, int): + assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" + assert config.attention_window > 0, "`config.attention_window` has to be positive" + config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer + else: + assert len(config.attention_window) == config.num_hidden_layers, ( + "`len(config.attention_window)` should equal `config.num_hidden_layers`. " + f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" + ) + + self.embeddings = RobertaEmbeddings(config) + self.encoder = LongformerEncoder(config) + self.pooler = BertPooler(config) + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def _pad_to_window_size( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor, + position_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + pad_token_id: int, + ): + """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" + # padding + attention_window = ( + self.config.attention_window + if isinstance(self.config.attention_window, int) + else max(self.config.attention_window) + ) + + assert attention_window % 2 == 0, f"`attention_window` should be an even value. 
Given {attention_window}" + input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape + batch_size, seq_len = input_shape[:2] + + padding_len = (attention_window - seq_len % attention_window) % attention_window + if padding_len > 0: + logger.info( + "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format( + seq_len, seq_len + padding_len, attention_window + ) + ) + if input_ids is not None: + input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id) + if position_ids is not None: + # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings + position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id) + if inputs_embeds is not None: + input_ids_padding = inputs_embeds.new_full( + (batch_size, padding_len), pad_token_id, dtype=torch.long, # same `pad_token_id` argument as the branches above + ) + inputs_embeds_padding = self.embeddings(input_ids_padding) + inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2) + + attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens + token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0 + + return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds + + def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor): + # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) + # (global_attention_mask + 1) => 1 for local attention, 2 for global attention + # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention + if attention_mask is not None: + attention_mask = attention_mask * (global_attention_mask + 1) + else: + # simply use `global_attention_mask` as `attention_mask` + # if no `attention_mask` is given + attention_mask = global_attention_mask + 1 + return attention_mask + + 
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def forward( + self, + input_ids=None, + attention_mask=None, + global_attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + ): + r""" + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs: + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + >>> import torch + >>> from transformers import LongformerModel, LongformerTokenizer + + >>> model = LongformerModel.from_pretrained('allenai/longformer-base-4096') + >>> tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096') + + >>> SAMPLE_TEXT = ' '.join(['Hello world! 
'] * 1000) # long input document + >>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1 + + >>> # Attention mask values -- 0: no attention, 1: local attention, 2: global attention + >>> attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention + >>> attention_mask[:, [1, 4, 21,]] = 2 # Set global attention based on the task. For example, + ... # classification: the token + ... # QA: question tokens + ... # LM: potentially on the beginning of sentences and paragraphs + >>> sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask) + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # merge `global_attention_mask` and `attention_mask` + if global_attention_mask is not None: + attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) + + padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + 
inputs_embeds=inputs_embeds, + pad_token_id=self.config.pad_token_id, + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) + + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + outputs = (sequence_output, pooled_output,) + encoder_outputs[ + 1: + ] # add hidden_states and attentions if they are here + + # undo padding + if padding_len > 0: + # `output` has the following tensors: sequence_output, pooled_output, (hidden_states), (attentions) + # `sequence_output`: unpad because the calling function is expecting a length == input_ids.size(1) + # `pooled_output`: independent of the sequence length + # `hidden_states`: mainly used for debugging and analysis, so keep the padding + # `attentions`: mainly used for debugging and analysis, so keep the padding + outputs = outputs[0][:, :-padding_len], *outputs[1:] + + return outputs + + +@add_start_docstrings("""Longformer Model with a `language modeling` head on top. 
""", LONGFORMER_START_DOCSTRING) +class LongformerForMaskedLM(LongformerPreTrainedModel): + config_class = LongformerConfig + base_model_prefix = "longformer" + + def __init__(self, config): + super().__init__(config) + + self.longformer = LongformerModel(config) + self.lm_head = RobertaLMHead(config) + + self.init_weights() + + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def forward( + self, + input_ids=None, + attention_mask=None, + global_attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + **kwargs + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the masked language modeling loss. + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs: + masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Masked language modeling loss. + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + >>> import torch + >>> from transformers import LongformerForMaskedLM, LongformerTokenizer + + >>> model = LongformerForMaskedLM.from_pretrained('allenai/longformer-base-4096') + >>> tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096') + + >>> SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000) # long input document + >>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1 + + >>> attention_mask = None # default is local attention everywhere, which is a good choice for MaskedLM + ... # check ``LongformerModel.forward`` for more details how to set `attention_mask` + >>> loss, prediction_scores = model(input_ids, attention_mask=attention_mask, labels=input_ids) + """ + + if "masked_lm_labels" in kwargs: + warnings.warn( + "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", + DeprecationWarning, + ) + labels = kwargs.pop("masked_lm_labels") + assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." 
+ + outputs = self.longformer( + input_ids, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here + + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + outputs = (masked_lm_loss,) + outputs + + return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) + + +@add_start_docstrings( + """Longformer Model transformer with a sequence classification/regression head on top (a linear layer + on top of the pooled output) e.g. for GLUE tasks. """, + LONGFORMER_START_DOCSTRING, +) +class LongformerForSequenceClassification(BertPreTrainedModel): + config_class = LongformerConfig + base_model_prefix = "longformer" + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.longformer = LongformerModel(config) + self.classifier = LongformerClassificationHead(config) + + self.init_weights() + + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-base-4096") + def forward( + self, + input_ids=None, + attention_mask=None, + global_attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the sequence classification/regression loss. 
+ Indices should be in :obj:`[0, ..., config.num_labels - 1]`. + If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + if global_attention_mask is None: + logger.info("Initializing global attention on CLS token...") + mask_like = input_ids if input_ids is not None else inputs_embeds[..., 0] # (batch, seq) reference; `input_ids` may be None when `inputs_embeds` is given + global_attention_mask = torch.zeros_like(mask_like, dtype=torch.long) + global_attention_mask[:, 0] = 1 # global attention on cls token + + outputs = self.longformer( + input_ids, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + outputs = (logits,) + outputs[2:] + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss,) + outputs + + return outputs # (loss), logits, (hidden_states), (attentions) + + +class LongformerClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, hidden_states, **kwargs): + hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS]) + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + output = self.out_proj(hidden_states) + return output + + +@add_start_docstrings( + """Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / TriviaQA (a linear layers on top of + the hidden-states output to compute `span start logits` and `span end logits`). 
""", + LONGFORMER_START_DOCSTRING, +) +class LongformerForQuestionAnswering(BertPreTrainedModel): + config_class = LongformerConfig + base_model_prefix = "longformer" + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.longformer = LongformerModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def forward( + self, + input_ids=None, + attention_mask=None, + global_attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. 
+ start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-start scores (before SoftMax). + end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-end scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ + Examples:: + + >>> from transformers import LongformerTokenizer, LongformerForQuestionAnswering + >>> import torch + + >>> tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa") + >>> model = LongformerForQuestionAnswering.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + >>> encoding = tokenizer(question, text, return_tensors="pt") + >>> input_ids = encoding["input_ids"] + + >>> # default is local attention everywhere + >>> # the forward method will automatically set global attention on question tokens + >>> attention_mask = encoding["attention_mask"] + + >>> start_scores, end_scores = model(input_ids, attention_mask=attention_mask) + >>> all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist()) + + >>> answer_tokens = all_tokens[torch.argmax(start_scores) :torch.argmax(end_scores)+1] + >>> answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens)) # remove space prepending space token + + """ + + # set global attention on question tokens + if global_attention_mask is None: + logger.info("Initializing global attention on question tokens...") + # put global attention on all tokens until `config.sep_token_id` is reached + global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id) + + outputs = self.longformer( + input_ids, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + outputs = (start_logits, end_logits,) + outputs[2:] + if start_positions 
@add_start_docstrings(
    """Longformer Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    LONGFORMER_START_DOCSTRING,
)
class LongformerForTokenClassification(BertPreTrainedModel):
    config_class = LongformerConfig
    base_model_prefix = "longformer"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.longformer = LongformerModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # one logit per label for every token
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-base-4096")
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        global_attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Classification loss.
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """

        outputs = self.longformer(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                # positions where the mask is not 1 get `ignore_index` as label, so the loss skips them
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)


@add_start_docstrings(
    """Longformer Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    LONGFORMER_START_DOCSTRING,
)
class LongformerForMultipleChoice(BertPreTrainedModel):
    config_class = LongformerConfig
    base_model_prefix = "longformer"

    def __init__(self, config):
        super().__init__(config)

        self.longformer = LongformerModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # a single score per (sample, choice) pair; scores are regrouped per sample in `forward`
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-base-4096")
    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        global_attention_mask=None,
        labels=None,
        position_ids=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification loss.
        classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            `num_choices` is the second dimension of the input tensors. (see `input_ids` above).

            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Default global attention can only be derived from token ids (it looks for `config.sep_token_id`).
        # When the caller passes `inputs_embeds` only, leave the mask as `None` instead of indexing into `None`.
        if global_attention_mask is None and input_ids is not None:
            logger.info("Initializing global attention on multiple choice...")
            # put global attention on all tokens after `config.sep_token_id`
            global_attention_mask = torch.stack(
                [
                    _compute_global_attention_mask(input_ids[:, i], self.config.sep_token_id, before_sep_token=False)
                    for i in range(num_choices)
                ],
                dim=1,
            )

        # fold the `num_choices` dimension into the batch dimension so that every choice is encoded independently
        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_global_attention_mask = (
            global_attention_mask.view(-1, global_attention_mask.size(-1))
            if global_attention_mask is not None
            else None
        )
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.longformer(
            flat_input_ids,
            position_ids=flat_position_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            global_attention_mask=flat_global_attention_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # regroup the per-choice scores: one row per sample, one column per choice
        reshaped_logits = logits.view(-1, num_choices)

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
right converting columns into diagonals""" + """shift every row 1 step right, converting columns into diagonals""" total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size() chunked_hidden_states = F.pad( chunked_hidden_states, (0, window_overlap + 1) - ) # B x C x M x (L+M+1). Padding value is not important because it'll be overwritten - chunked_hidden_states = chunked_hidden_states.view(total_num_heads, num_chunks, -1) # B x C x ML+MM+M - chunked_hidden_states = chunked_hidden_states[:, :, :-window_overlap] # B x C x ML+MM + ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten + chunked_hidden_states = chunked_hidden_states.view( + total_num_heads, num_chunks, -1 + ) # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap+window_overlap + chunked_hidden_states = chunked_hidden_states[ + :, :, :-window_overlap + ] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap chunked_hidden_states = chunked_hidden_states.view( total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim - ) # B x C, M x L+M + ) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] return chunked_hidden_states @@ -154,7 +156,7 @@ def _chunk(hidden_states, window_overlap): hidden_states.size(2), ) - # use `as_strided` to make the chunks overlap with an overlap size = w + # use `as_strided` to make the chunks overlap with an overlap size = window_overlap chunk_size = list(hidden_states.size()) chunk_size[1] = chunk_size[1] * 2 - 1 @@ -174,7 +176,7 @@ def _mask_invalid_locations(self, input_tensor, affected_seq_len) -> torch.Tenso ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int): - """Matrix 
multiplicatio of query x key tensors using with a sliding window attention pattern. + """Matrix multiplication of query and key tensors using with a sliding window attention pattern. This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an overlap of size window_overlap""" batch_size, seq_len, num_heads, head_dim = query.size() @@ -193,26 +195,26 @@ def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tenso chunked_key = self._chunk(key, window_overlap) # matrix multipication - # bcxd: batch_size * num_heads x chunks x 2w x head_dim - # bcyd: batch_size * num_heads x chunks x 2w x head_dim - # bcxy: batch_size * num_heads x chunks x 2w x 2w + # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcxy: batch_size * num_heads x chunks x 2window_overlap x window_overlap chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunked_query, chunked_key)) # multiply # convert diagonals into columns - diagonal_chunked_attention_scores = self.pad_and_transpose_last_two_dims( + diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims( chunked_attention_scores, padding=(0, 0, 0, 1) ) - # allocate space for the overall attention matrix where the chunks are compined. The last dimension - # has (w * 2 + 1) columns. The first (w) columns are the w lower triangles (attention from a word to - # w previous words). The following column is attention score from each word to itself, then - # followed by w columns for the upper triangle. + # allocate space for the overall attention matrix where the chunks are combined. The last dimension + # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to + # window_overlap previous words). 
The following column is attention score from each word to itself, then + # followed by window_overlap columns for the upper triangle. diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty( (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1) ) - # copy parts from diagonal_chunked_attention_scores into the compined matrix of attentions + # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions # - copying the main diagonal and the upper triangle diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[ :, :, :window_overlap, : window_overlap + 1 @@ -239,8 +241,8 @@ def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tenso def _sliding_chunks_matmul_attn_probs_value( self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int ): - """Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. It is expecting the same output - format from _sliding_chunks_query_key_matmul""" + """Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. 
+ Returned tensor will be of the same shape as `attn_probs`""" batch_size, seq_len, num_heads, head_dim = value.size() assert seq_len % (window_overlap * 2) == 0 assert attn_probs.size()[:3] == value.size()[:3] @@ -292,7 +294,7 @@ def forward( # is index masked or global attention is_index_masked = attention_mask < 0 is_index_global_attn = attention_mask > 0 - is_global_attn = not torch.all(~is_index_global_attn).item() + is_global_attn = any(is_index_global_attn.flatten()) hidden_states = hidden_states.transpose(0, 1) @@ -317,7 +319,7 @@ def forward( query_vectors, key_vectors, self.one_sided_attn_window_size ) - # from (batch_size x seq_len) to (batch_size x seq_len x num_heads x hidden_size) + # values to pad for attention probs remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1) # cast to fp32/fp16 then replace 1's with -inf @@ -400,7 +402,7 @@ def forward( # compute value for global attention and overwrite to attention output # TODO: remove the redundant computation if is_global_attn: - global_attn_output = self.compute_global_attn_output_from_hidden( + global_attn_output = self._compute_global_attn_output_from_hidden( hidden_states=hidden_states, max_num_global_attn_indices=max_num_global_attn_indices, is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, @@ -440,7 +442,7 @@ def forward( @staticmethod def _get_global_attn_indices(is_index_global_attn): - """ compute global attn indices required through out forward fn """ + """ compute global attn indices required throughout forward pass """ # helper variable num_global_attn_indices = is_index_global_attn.long().sum(dim=1) @@ -526,7 +528,7 @@ def _compute_attn_output_with_global_indices( ) return attn_output_only_global + attn_output_without_global - def compute_global_attn_output_from_hidden( + def _compute_global_attn_output_from_hidden( self, hidden_states, max_num_global_attn_indices, From 90d2aa673df6159afd7b92402f9f2c2ea12611f0 Mon Sep 17 
00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Jul 2020 15:28:08 +0000 Subject: [PATCH 19/19] fix transpose --- + | 1562 ----------------------- src/transformers/modeling_longformer.py | 4 +- 2 files changed, 3 insertions(+), 1563 deletions(-) delete mode 100644 + diff --git a/+ b/+ deleted file mode 100644 index 0f133eb7db14..000000000000 --- a/+ +++ /dev/null @@ -1,1562 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Longformer model. 
""" - -import logging -import math -import warnings - -import torch -import torch.nn as nn -from torch.nn import CrossEntropyLoss, MSELoss -from torch.nn import functional as F - -from .configuration_longformer import LongformerConfig -from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_bert import BertIntermediate, BertLayerNorm, BertOutput, BertPooler, BertPreTrainedModel, BertSelfOutput -from .modeling_roberta import RobertaEmbeddings, RobertaLMHead -from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer - - -logger = logging.getLogger(__name__) - -_TOKENIZER_FOR_DOC = "LongformerTokenizer" - -LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "allenai/longformer-base-4096", - "allenai/longformer-large-4096", - "allenai/longformer-large-4096-finetuned-triviaqa", - "allenai/longformer-base-4096-extra.pos.embd.only", - "allenai/longformer-large-4096-extra.pos.embd.only", - # See all Longformer models at https://huggingface.co/models?filter=longformer -] - - -def _get_question_end_index(input_ids, sep_token_id): - """ - Computes the index of the first occurance of `sep_token_id`. - """ - - sep_token_indices = (input_ids == sep_token_id).nonzero() - batch_size = input_ids.shape[0] - - assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions" - assert ( - sep_token_indices.shape[0] == 3 * batch_size - ), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error." - - return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1] - - -def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True): - """ - Computes global attention mask by putting attention on all tokens - before `sep_token_id` if `before_sep_token is True` else after - `sep_token_id`. 
- """ - - question_end_index = _get_question_end_index(input_ids, sep_token_id) - question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1 - # bool attention mask with True in locations of global attention - attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device) - if before_sep_token is True: - attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.uint8) - else: - # last token is separation token and should not be counted and in the middle are two separation tokens - attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.uint8) * ( - attention_mask.expand_as(input_ids) < input_ids.shape[-1] - ).to(torch.uint8) - - return attention_mask - - -class LongformerSelfAttention(nn.Module): - def __init__(self, config, layer_id): - super().__init__() - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads) - ) - self.num_heads = config.num_attention_heads - self.head_dim = int(config.hidden_size / config.num_attention_heads) - self.embed_dim = config.hidden_size - - self.query = nn.Linear(config.hidden_size, self.embed_dim) - self.key = nn.Linear(config.hidden_size, self.embed_dim) - self.value = nn.Linear(config.hidden_size, self.embed_dim) - - # separate projection layers for tokens with global attention - self.query_global = nn.Linear(config.hidden_size, self.embed_dim) - self.key_global = nn.Linear(config.hidden_size, self.embed_dim) - self.value_global = nn.Linear(config.hidden_size, self.embed_dim) - - self.dropout = config.attention_probs_dropout_prob - - self.layer_id = layer_id - attention_window = config.attention_window[self.layer_id] - assert ( - attention_window % 2 == 0 - ), f"`attention_window` for layer {self.layer_id} has to be an even value. 
Given {attention_window}" - assert ( - attention_window > 0 - ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}" - - self.one_sided_attn_window_size = attention_window // 2 - - @staticmethod - def _pad_and_transpose_last_two_dims(hidden_states_padded, padding): - """pads rows and then flips rows and columns""" - hidden_states_padded = F.pad( - hidden_states_padded, padding - ) # padding value is not important because it will be overwritten - hidden_states_padded = hidden_states_padded.transpose(-1, -2) - return hidden_states_padded - - @staticmethod - def _pad_by_window_overlap_except_last_row(chunked_hidden_states): - """shift every row 1 step right, converting columns into diagonals""" - total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size() - chunked_hidden_states = F.pad( - chunked_hidden_states, (0, window_overlap + 1) - ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten - chunked_hidden_states = chunked_hidden_states.view(total_num_heads, num_chunks, -1) # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap+window_overlap - chunked_hidden_states = chunked_hidden_states[:, :, :-window_overlap] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap - chunked_hidden_states = chunked_hidden_states.view( - total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim - ) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap - chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] - return chunked_hidden_states - - @staticmethod - def _chunk(hidden_states, window_overlap): - """convert into overlapping chunkings. 
Chunk size = 2w, overlap size = w""" - - # non-overlapping chunks of size = 2w - hidden_states = hidden_states.view( - hidden_states.size(0), - hidden_states.size(1) // (window_overlap * 2), - window_overlap * 2, - hidden_states.size(2), - ) - - # use `as_strided` to make the chunks overlap with an overlap size = window_overlap - chunk_size = list(hidden_states.size()) - chunk_size[1] = chunk_size[1] * 2 - 1 - - chunk_stride = list(hidden_states.stride()) - chunk_stride[1] = chunk_stride[1] // 2 - return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) - - def _mask_invalid_locations(self, input_tensor, affected_seq_len) -> torch.Tensor: - beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0]) - beginning_mask = beginning_mask_2d[None, :, None, :] - ending_mask = beginning_mask.flip(dims=(1, 3)) - beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] - beginning_mask = beginning_mask.expand(beginning_input.size()) - beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 - ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] - ending_mask = ending_mask.expand(ending_input.size()) - ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 - - def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int): - """Matrix multiplication of query and key tensors using with a sliding window attention pattern. - This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) - with an overlap of size window_overlap""" - batch_size, seq_len, num_heads, head_dim = query.size() - assert ( - seq_len % (window_overlap * 2) == 0 - ), f"Sequence length should be multiple of {window_overlap * 2}. 
Given {seq_len}" - assert query.size() == key.size() - - chunks_count = seq_len // window_overlap - 1 - - # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 - query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) - key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) - - chunked_query = self._chunk(query, window_overlap) - chunked_key = self._chunk(key, window_overlap) - - # matrix multipication - # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim - # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim - # bcxy: batch_size * num_heads x chunks x 2window_overlap x window_overlap - chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunked_query, chunked_key)) # multiply - - # convert diagonals into columns - diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims( - chunked_attention_scores, padding=(0, 0, 0, 1) - ) - - # allocate space for the overall attention matrix where the chunks are combined. The last dimension - # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to - # window_overlap previous words). The following column is attention score from each word to itself, then - # followed by window_overlap columns for the upper triangle. 
- - diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty( - (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1) - ) - - # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions - # - copying the main diagonal and the upper triangle - diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[ - :, :, :window_overlap, : window_overlap + 1 - ] - diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[ - :, -1, window_overlap:, : window_overlap + 1 - ] - # - copying the lower triangle - diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[ - :, :, -(window_overlap + 1) : -1, window_overlap + 1 : - ] - diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[ - :, 0, : window_overlap - 1, 1 - window_overlap : - ] - - # separate batch_size and num_heads dimensions again - diagonal_attention_scores = diagonal_attention_scores.view( - batch_size, num_heads, seq_len, 2 * window_overlap + 1 - ).transpose(2, 1) - - self._mask_invalid_locations(diagonal_attention_scores, window_overlap) - return diagonal_attention_scores - - def _sliding_chunks_matmul_attn_probs_value( - self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int - ): - """Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. 
- Returned tensor will be of the same shape as `attn_probs`""" - batch_size, seq_len, num_heads, head_dim = value.size() - assert seq_len % (window_overlap * 2) == 0 - assert attn_probs.size()[:3] == value.size()[:3] - assert attn_probs.size(3) == 2 * window_overlap + 1 - chunks_count = seq_len // window_overlap - 1 - # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap - chunked_attn_probs = attn_probs.transpose(1, 2).reshape( - batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1 - ) - - # group batch_size and num_heads dimensions into one - value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) - - # pad seq_len with w at the beginning of the sequence and another window overlap at the end - padded_value = F.pad(value, (0, 0, window_overlap, window_overlap), value=-1) - - # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap - chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim) - chunked_value_stride = padded_value.stride() - chunked_value_stride = ( - chunked_value_stride[0], - window_overlap * chunked_value_stride[1], - chunked_value_stride[1], - chunked_value_stride[2], - ) - chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride) - - chunked_attn_probs = self._pad_by_window_overlap_except_last_row(chunked_attn_probs) - - context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value)) - return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2) - - def forward( - self, hidden_states, attention_mask=None, output_attentions=False, - ): - """ - LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`. - Padding to `attention_window` happens in LongformerModel.forward to avoid redoing the padding on each layer. 
- - The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to - -ve: no attention - 0: local attention - +ve: global attention - - """ - - attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) - - # is index masked or global attention - is_index_masked = attention_mask < 0 - is_index_global_attn = attention_mask > 0 - is_global_attn = any(is_index_global_attn).item() - - hidden_states = hidden_states.transpose(0, 1) - - # project hidden states - query_vectors = self.query(hidden_states) - key_vectors = self.key(hidden_states) - value_vectors = self.value(hidden_states) - - seq_len, batch_size, embed_dim = hidden_states.size() - assert ( - embed_dim == self.embed_dim - ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}" - - # normalize query - query_vectors /= math.sqrt(self.head_dim) - - query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - - # attn_probs = (batch_size, seq_len, num_heads, window*2+1) - attn_scores = self._sliding_chunks_query_key_matmul( - query_vectors, key_vectors, self.one_sided_attn_window_size - ) - - # values to pad for attention probs - remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1) - - # cast to fp32/fp16 then replace 1's with -inf - float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill( - remove_from_windowed_attention_mask, -10000.0 - ) - # diagonal mask with zeros everywhere and -inf inplace of padding - diagonal_mask = self._sliding_chunks_query_key_matmul( - float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size - ) - - # pad local attention probs - attn_scores += diagonal_mask - - assert list(attn_scores.size()) == [ - batch_size, - seq_len, - self.num_heads, - self.one_sided_attn_window_size * 2 + 1, - ], f"attn_probs 
should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}" - - # compute local attention probs from global attention keys and contact over window dim - if is_global_attn: - # compute global attn indices required through out forward fn - ( - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - ) = self._get_global_attn_indices(is_index_global_attn) - # calculate global attn probs from global key - global_key_attn_scores = self._concat_with_global_key_attn_probs( - query_vectors=query_vectors, - key_vectors=key_vectors, - max_num_global_attn_indices=max_num_global_attn_indices, - is_index_global_attn_nonzero=is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, - ) - # concat to attn_probs - # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) - attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1) - - # free memory - del global_key_attn_scores - - attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability - attn_probs = attn_probs_fp32.type_as(attn_scores) - - # free memory - del attn_probs_fp32 - - # softmax sometimes inserts NaN if all positions are masked, replace them with 0 - attn_probs = torch.masked_fill(attn_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0) - - # apply dropout - attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training) - - value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - - # compute local attention output with global attention value and add - if is_global_attn: - # compute sum of global and local attn - attn_output = self._compute_attn_output_with_global_indices( - 
value_vectors=value_vectors, - attn_probs=attn_probs, - max_num_global_attn_indices=max_num_global_attn_indices, - is_index_global_attn_nonzero=is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - ) - else: - # compute local attn only - attn_output = self._sliding_chunks_matmul_attn_probs_value( - attn_probs, value_vectors, self.one_sided_attn_window_size - ) - - assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" - attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() - - # compute value for global attention and overwrite to attention output - # TODO: remove the redundant computation - if is_global_attn: - global_attn_output = self.compute_global_attn_output_from_hidden( - hidden_states=hidden_states, - max_num_global_attn_indices=max_num_global_attn_indices, - is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - is_index_global_attn_nonzero=is_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, - is_index_masked=is_index_masked, - ) - - # get only non zero global attn output - nonzero_global_attn_output = global_attn_output[ - is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1] - ] - # overwrite values with global attention - attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view( - len(is_local_index_global_attn_nonzero[0]), -1 - ) - - attn_output = attn_output.transpose(0, 1) - - if output_attentions: - if is_global_attn: - # With global attention, return global attention probabilities only - # batch_size x num_heads x max_num_global_attention_tokens x sequence_length - # which is the attention weights from tokens with global attention to all tokens - # It doesn't not return local attention - # In case of variable number of global attantion in the rows of a batch, - # attn_probs are padded with 
-10000.0 attention scores - attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) - else: - # without global attention, return local attention probabilities - # batch_size x num_heads x sequence_length x window_size - # which is the attention weights of every token attending to its neighbours - attn_probs = attn_probs.permute(0, 2, 1, 3) - - outputs = (attn_output, attn_probs) if output_attentions else (attn_output,) - return outputs - - @staticmethod - def _get_global_attn_indices(is_index_global_attn): - """ compute global attn indices required throughout forward pass """ - # helper variable - num_global_attn_indices = is_index_global_attn.long().sum(dim=1) - - # max number of global attn indices in batch - max_num_global_attn_indices = num_global_attn_indices.max() - - # indices of global attn - is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True) - - # helper variable - is_local_index_global_attn = torch.arange( - max_num_global_attn_indices, device=is_index_global_attn.device - ) < num_global_attn_indices.unsqueeze(dim=-1) - - # location of the non-padding values within global attention indices - is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True) - - # location of the padding values within global attention indices - is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True) - return ( - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - ) - - def _concat_with_global_key_attn_probs( - self, - key_vectors, - query_vectors, - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - ): - batch_size = key_vectors.shape[0] - - # create only global key vectors - key_vectors_only_global = key_vectors.new_zeros( - batch_size, max_num_global_attn_indices, 
self.num_heads, self.head_dim - ) - key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero] - # (batch_size, seq_len, num_heads, max_num_global_attn_indices) - attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) - attn_probs_from_global_key[ - is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1] - ] = -10000.0 - return attn_probs_from_global_key - - def _compute_attn_output_with_global_indices( - self, - value_vectors, - attn_probs, - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - ): - batch_size = attn_probs.shape[0] - - # cut local attn probs to global only - attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices) - # get value vectors for global only - value_vectors_only_global = value_vectors.new_zeros( - batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim - ) - value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero] - - # use `matmul` because `einsum` crashes sometimes with fp16 - # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) - # compute attn output only global - attn_output_only_global = torch.matmul( - attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2) - ).transpose(1, 2) - - # reshape attn probs - attn_probs_without_global = attn_probs.narrow( - -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices - ).contiguous() - - # compute attn output with global - attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( - attn_probs_without_global, value_vectors, self.one_sided_attn_window_size - ) - return attn_output_only_global + attn_output_without_global - - def compute_global_attn_output_from_hidden( - self, - hidden_states, - max_num_global_attn_indices, - 
is_local_index_global_attn_nonzero, - is_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - is_index_masked, - ): - seq_len, batch_size = hidden_states.shape[:2] - - # prepare global hidden states - global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim) - global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[ - is_index_global_attn_nonzero[::-1] - ] - - # global key, query, value - global_query_vectors_only_global = self.query_global(global_attn_hidden_states) - global_key_vectors = self.key_global(hidden_states) - global_value_vectors = self.value_global(hidden_states) - - # normalize - global_query_vectors_only_global /= math.sqrt(self.head_dim) - - # reshape - global_query_vectors_only_global = ( - global_query_vectors_only_global.contiguous() - .view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim) - .transpose(0, 1) - ) # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim) - global_key_vectors = ( - global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) - ) # batch_size * self.num_heads, seq_len, head_dim) - global_value_vectors = ( - global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) - ) # batch_size * self.num_heads, seq_len, head_dim) - - # compute attn scores - global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2)) - - assert list(global_attn_scores.size()) == [ - batch_size * self.num_heads, - max_num_global_attn_indices, - seq_len, - ], f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attn_scores.size()}." 
- - global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) - - global_attn_scores[ - is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], : - ] = -10000.0 - - global_attn_scores = global_attn_scores.masked_fill(is_index_masked.unsqueeze(1).unsqueeze(2), -10000.0,) - - global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len) - - # compute global attn probs - global_attn_probs_float = F.softmax( - global_attn_scores, dim=-1, dtype=torch.float32 - ) # use fp32 for numerical stability - - global_attn_probs = F.dropout( - global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training - ) - - # global attn output - global_attn_output = torch.bmm(global_attn_probs, global_value_vectors) - - assert list(global_attn_output.size()) == [ - batch_size * self.num_heads, - max_num_global_attn_indices, - self.head_dim, - ], f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}." 
- - global_attn_output = global_attn_output.view( - batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim - ) - return global_attn_output - - -class LongformerAttention(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.self = LongformerSelfAttention(config, layer_id) - self.output = BertSelfOutput(config) - self.pruned_heads = set() - - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads - ) - - # Prune linear layers - self.self.query = prune_linear_layer(self.self.query, index) - self.self.key = prune_linear_layer(self.self.key, index) - self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.self.num_attention_heads = self.self.num_attention_heads - len(heads) - self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) - - def forward( - self, hidden_states, attention_mask=None, output_attentions=False, - ): - self_outputs = self.self(hidden_states, attention_mask, output_attentions,) - attn_output = self.output(self_outputs[0], hidden_states) - outputs = (attn_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - -class LongformerLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.attention = LongformerAttention(config, layer_id) - self.intermediate = BertIntermediate(config) - self.output = BertOutput(config) - - def forward( - self, hidden_states, attention_mask=None, output_attentions=False, - ): - self_attn_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,) - attn_output = self_attn_outputs[0] - outputs = self_attn_outputs[1:] # add self 
attentions if we output attention weights - - intermediate_output = self.intermediate(attn_output) - layer_output = self.output(intermediate_output, attn_output) - outputs = (layer_output,) + outputs - return outputs - - -class LongformerEncoder(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)]) - - def forward( - self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False, - ): - all_hidden_states = () - all_attentions = () - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if getattr(self.config, "gradient_checkpointing", False): - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, attention_mask, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, output_attentions,) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = (hidden_states,) - if output_hidden_states: - outputs = outputs + (all_hidden_states,) - if output_attentions: - outputs = outputs + (all_attentions,) - return outputs # last-layer hidden state, (all hidden states), (all attentions) - - -class LongformerPreTrainedModel(PreTrainedModel): - """ An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained - models. 
- """ - - config_class = LongformerConfig - base_model_prefix = "longformer" - - def _init_weights(self, module): - """ Initialize the weights """ - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, BertLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - -LONGFORMER_START_DOCSTRING = r""" - - This model is a PyTorch `torch.nn.Module `__ sub-class. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general - usage and behavior. - - Parameters: - config (:class:`~transformers.LongformerConfig`): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the configuration. - Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. -""" - -LONGFORMER_INPUTS_DOCSTRING = r""" - Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using :class:`transformers.LonmgformerTokenizer`. - See :func:`transformers.PreTrainedTokenizer.encode` and - :func:`transformers.PreTrainedTokenizer.__call__` for details. - - `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): - Mask to avoid performing attention on padding token indices. - Mask values selected in ``[0, 1]``: - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. - - `What are attention masks? 
<../glossary.html#attention-mask>`__ - - global_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): - Mask to decide the attention given on each token, local attention or global attenion. - Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is important for - task-specific finetuning because it makes the model more flexible at representing the task. For example, - for classification, the token should be given global attention. For QA, all question tokens should also have - global attention. Please refer to the `Longformer paper `__ for more details. - Mask values selected in ``[0, 1]``: - ``0`` for local attention (a sliding window attention), - ``1`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them). - - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): - Segment token indices to indicate first and second portions of the inputs. - Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` - corresponds to a `sentence B` token - - `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): - Indices of positions of each input sequence tokens in the position embeddings. - Selected in the range ``[0, config.max_position_embeddings - 1]``. - - `What are position IDs? <../glossary.html#position-ids>`_ - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. 
- output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): - If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. -""" - - -@add_start_docstrings( - "The bare Longformer Model outputting raw hidden-states without any specific head on top.", - LONGFORMER_START_DOCSTRING, -) -class LongformerModel(LongformerPreTrainedModel): - """ - This class overrides :class:`~transformers.RobertaModel` to provide the ability to process - long sequences following the selfattention approach described in `Longformer: the Long-Document Transformer - `__ by Iz Beltagy, Matthew E. Peters, and Arman Cohan. Longformer selfattention - combines a local (sliding window) and global attention to extend to long documents without the O(n^2) increase in - memory and compute. - - The selfattention module `LongformerSelfAttention` implemented here supports the combination of local and - global attention but it lacks support for autoregressive attention and dilated attention. Autoregressive - and dilated attention are more relevant for autoregressive language modeling than finetuning on downstream - tasks. Future release will add support for autoregressive attention, but the support for dilated attention - requires a custom CUDA kernel to be memory and compute efficient. 
- - """ - - config_class = LongformerConfig - base_model_prefix = "longformer" - - def __init__(self, config): - super().__init__(config) - self.config = config - - if isinstance(config.attention_window, int): - assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" - assert config.attention_window > 0, "`config.attention_window` has to be positive" - config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer - else: - assert len(config.attention_window) == config.num_hidden_layers, ( - "`len(config.attention_window)` should equal `config.num_hidden_layers`. " - f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" - ) - - self.embeddings = RobertaEmbeddings(config) - self.encoder = LongformerEncoder(config) - self.pooler = BertPooler(config) - - self.init_weights() - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def _prune_heads(self, heads_to_prune): - """ Prunes heads of the model. - heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - See base class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - def _pad_to_window_size( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor, - position_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - pad_token_id: int, - ): - """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" - # padding - attention_window = ( - self.config.attention_window - if isinstance(self.config.attention_window, int) - else max(self.config.attention_window) - ) - - assert attention_window % 2 == 0, f"`attention_window` should be an even value. 
Given {attention_window}" - input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape - batch_size, seq_len = input_shape[:2] - - padding_len = (attention_window - seq_len % attention_window) % attention_window - if padding_len > 0: - logger.info( - "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format( - seq_len, seq_len + padding_len, attention_window - ) - ) - if input_ids is not None: - input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id) - if position_ids is not None: - # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings - position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id) - if inputs_embeds is not None: - input_ids_padding = inputs_embeds.new_full( - (batch_size, padding_len), self.config.pad_token_id, dtype=torch.long, - ) - inputs_embeds_padding = self.embeddings(input_ids_padding) - inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2) - - attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens - token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0 - - return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds - - def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor): - # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) - # (global_attention_mask + 1) => 1 for local attention, 2 for global attention - # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention - if attention_mask is not None: - attention_mask = attention_mask * (global_attention_mask + 1) - else: - # simply use `global_attention_mask` as `attention_mask` - # if no `attention_mask` is given - attention_mask = global_attention_mask + 1 - return attention_mask - - 
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) - def forward( - self, - input_ids=None, - attention_mask=None, - global_attention_mask=None, - token_type_ids=None, - position_ids=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - ): - r""" - - Returns: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs: - prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - - Examples:: - - >>> import torch - >>> from transformers import LongformerModel, LongformerTokenizer - - >>> model = LongformerModel.from_pretrained('allenai/longformer-base-4096') - >>> tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096') - - >>> SAMPLE_TEXT = ' '.join(['Hello world! 
'] * 1000) # long input document - >>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1 - - >>> # Attention mask values -- 0: no attention, 1: local attention, 2: global attention - >>> attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention - >>> attention_mask[:, [1, 4, 21,]] = 2 # Set global attention based on the task. For example, - ... # classification: the token - ... # QA: question tokens - ... # LM: potentially on the beginning of sentences and paragraphs - >>> sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask) - """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if attention_mask is None: - attention_mask = torch.ones(input_shape, device=device) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - # merge `global_attention_mask` and `attention_mask` - if global_attention_mask is not None: - attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) - - padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - 
inputs_embeds=inputs_embeds, - pad_token_id=self.config.pad_token_id, - ) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) - - embedding_output = self.embeddings( - input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds - ) - - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) - - outputs = (sequence_output, pooled_output,) + encoder_outputs[ - 1: - ] # add hidden_states and attentions if they are here - - # undo padding - if padding_len > 0: - # `output` has the following tensors: sequence_output, pooled_output, (hidden_states), (attentions) - # `sequence_output`: unpad because the calling function is expecting a length == input_ids.size(1) - # `pooled_output`: independent of the sequence length - # `hidden_states`: mainly used for debugging and analysis, so keep the padding - # `attentions`: mainly used for debugging and analysis, so keep the padding - outputs = outputs[0][:, :-padding_len], *outputs[1:] - - return outputs - - -@add_start_docstrings("""Longformer Model with a `language modeling` head on top. 
""", LONGFORMER_START_DOCSTRING) -class LongformerForMaskedLM(LongformerPreTrainedModel): - config_class = LongformerConfig - base_model_prefix = "longformer" - - def __init__(self, config): - super().__init__(config) - - self.longformer = LongformerModel(config) - self.lm_head = RobertaLMHead(config) - - self.init_weights() - - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) - def forward( - self, - input_ids=None, - attention_mask=None, - global_attention_mask=None, - token_type_ids=None, - position_ids=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - **kwargs - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Labels for computing the masked language modeling loss. - Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) - Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels - in ``[0, ..., config.vocab_size]`` - kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): - Used to hide legacy arguments that have been deprecated. - - Returns: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs: - masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: - Masked language modeling loss. - prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - - Examples:: - - >>> import torch - >>> from transformers import LongformerForMaskedLM, LongformerTokenizer - - >>> model = LongformerForMaskedLM.from_pretrained('allenai/longformer-base-4096') - >>> tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096') - - >>> SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000) # long input document - >>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1 - - >>> attention_mask = None # default is local attention everywhere, which is a good choice for MaskedLM - ... # check ``LongformerModel.forward`` for more details how to set `attention_mask` - >>> loss, prediction_scores = model(input_ids, attention_mask=attention_mask, labels=input_ids) - """ - - if "masked_lm_labels" in kwargs: - warnings.warn( - "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", - DeprecationWarning, - ) - labels = kwargs.pop("masked_lm_labels") - assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." 
- - outputs = self.longformer( - input_ids, - attention_mask=attention_mask, - global_attention_mask=global_attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - - outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here - - if labels is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - outputs = (masked_lm_loss,) + outputs - - return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) - - -@add_start_docstrings( - """Longformer Model transformer with a sequence classification/regression head on top (a linear layer - on top of the pooled output) e.g. for GLUE tasks. """, - LONGFORMER_START_DOCSTRING, -) -class LongformerForSequenceClassification(BertPreTrainedModel): - config_class = LongformerConfig - base_model_prefix = "longformer" - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.longformer = LongformerModel(config) - self.classifier = LongformerClassificationHead(config) - - self.init_weights() - - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) - @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-base-4096") - def forward( - self, - input_ids=None, - attention_mask=None, - global_attention_mask=None, - token_type_ids=None, - position_ids=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): - Labels for computing the sequence classification/regression loss. 
- Indices should be in :obj:`[0, ..., config.num_labels - 1]`. - If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), - If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - Returns: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs: - loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. 
- """ - - if global_attention_mask is None: - logger.info("Initializing global attention on CLS token...") - global_attention_mask = torch.zeros_like(input_ids) - # global attention on cls token - global_attention_mask[:, 0] = 1 - - outputs = self.longformer( - input_ids, - attention_mask=attention_mask, - global_attention_mask=global_attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - sequence_output = outputs[0] - logits = self.classifier(sequence_output) - - outputs = (logits,) + outputs[2:] - if labels is not None: - if self.num_labels == 1: - # We are doing regression - loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - outputs = (loss,) + outputs - - return outputs # (loss), logits, (hidden_states), (attentions) - - -class LongformerClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.out_proj = nn.Linear(config.hidden_size, config.num_labels) - - def forward(self, hidden_states, **kwargs): - hidden_states = hidden_states[:, 0, :] # take token (equiv. to [CLS]) - hidden_states = self.dropout(hidden_states) - hidden_states = self.dense(hidden_states) - hidden_states = torch.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - output = self.out_proj(hidden_states) - return output - - -@add_start_docstrings( - """Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / TriviaQA (a linear layers on top of - the hidden-states output to compute `span start logits` and `span end logits`). 
""", - LONGFORMER_START_DOCSTRING, -) -class LongformerForQuestionAnswering(BertPreTrainedModel): - config_class = LongformerConfig - base_model_prefix = "longformer" - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.longformer = LongformerModel(config) - self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - - self.init_weights() - - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) - def forward( - self, - input_ids=None, - attention_mask=None, - global_attention_mask=None, - token_type_ids=None, - position_ids=None, - inputs_embeds=None, - start_positions=None, - end_positions=None, - output_attentions=None, - output_hidden_states=None, - ): - r""" - start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). - Position outside of the sequence are not taken into account for computing the loss. - end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). - Position outside of the sequence are not taken into account for computing the loss. - Returns: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs: - loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): - Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. 
- start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): - Span-start scores (before SoftMax). - end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): - Span-end scores (before SoftMax). - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. 
- - Examples:: - - >>> from transformers import LongformerTokenizer, LongformerForQuestionAnswering - >>> import torch - - >>> tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa") - >>> model = LongformerForQuestionAnswering.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa") - - >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" - >>> encoding = tokenizer(question, text, return_tensors="pt") - >>> input_ids = encoding["input_ids"] - - >>> # default is local attention everywhere - >>> # the forward method will automatically set global attention on question tokens - >>> attention_mask = encoding["attention_mask"] - - >>> start_scores, end_scores = model(input_ids, attention_mask=attention_mask) - >>> all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist()) - - >>> answer_tokens = all_tokens[torch.argmax(start_scores) :torch.argmax(end_scores)+1] - >>> answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens)) # remove space prepending space token - - """ - - # set global attention on question tokens - if global_attention_mask is None: - logger.info("Initializing global attention on question tokens...") - # put global attention on all tokens until `config.sep_token_id` is reached - global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id) - - outputs = self.longformer( - input_ids, - attention_mask=attention_mask, - global_attention_mask=global_attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - outputs = (start_logits, end_logits,) + outputs[2:] - if start_positions 
is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions.clamp_(0, ignored_index) - end_positions.clamp_(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - outputs = (total_loss,) + outputs - - return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) - - -@add_start_docstrings( - """Longformer Model with a token classification head on top (a linear layer on top of - the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, - LONGFORMER_START_DOCSTRING, -) -class LongformerForTokenClassification(BertPreTrainedModel): - config_class = LongformerConfig - base_model_prefix = "longformer" - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.longformer = LongformerModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - self.init_weights() - - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) - @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-base-4096") - def forward( - self, - input_ids=None, - attention_mask=None, - global_attention_mask=None, - token_type_ids=None, - position_ids=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, 
`optional`, defaults to :obj:`None`): - Labels for computing the token classification loss. - Indices should be in ``[0, ..., config.num_labels - 1]``. - - Returns: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs: - loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) : - Classification loss. - scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`) - Classification scores (before SoftMax). - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. 
- """ - - outputs = self.longformer( - input_ids, - attention_mask=attention_mask, - global_attention_mask=global_attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here - - if labels is not None: - loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels) - active_labels = torch.where( - active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) - ) - loss = loss_fct(active_logits, active_labels) - else: - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - outputs = (loss,) + outputs - - return outputs # (loss), scores, (hidden_states), (attentions) - - -@add_start_docstrings( - """Longformer Model with a multiple choice classification head on top (a linear layer on top of - the pooled output and a softmax) e.g. for RocStories/SWAG tasks. 
""", - LONGFORMER_START_DOCSTRING, -) -class LongformerForMultipleChoice(BertPreTrainedModel): - config_class = LongformerConfig - base_model_prefix = "longformer" - - def __init__(self, config): - super().__init__(config) - - self.longformer = LongformerModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, 1) - - self.init_weights() - - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) - @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-base-4096") - def forward( - self, - input_ids=None, - token_type_ids=None, - attention_mask=None, - global_attention_mask=None, - labels=None, - position_ids=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): - Labels for computing the multiple choice classification loss. - Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension - of the input tensors. (see `input_ids` above) - - Returns: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs: - loss (:obj:`torch.FloatTensor`` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided): - Classification loss. - classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): - `num_choices` is the second dimension of the input tensors. (see `input_ids` above). - - Classification scores (before SoftMax). 
- hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] - - # set global attention on question tokens - if global_attention_mask is None: - logger.info("Initializing global attention on multiple choice...") - # put global attention on all tokens after `config.sep_token_id` - global_attention_mask = torch.stack( - [ - _compute_global_attention_mask(input_ids[:, i], self.config.sep_token_id, before_sep_token=False) - for i in range(num_choices) - ], - dim=1, - ) - - flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None - flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None - flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - flat_global_attention_mask = ( - global_attention_mask.view(-1, global_attention_mask.size(-1)) - if global_attention_mask is not None - else None - ) - flat_inputs_embeds = ( - inputs_embeds.view(-1, inputs_embeds.size(-2), 
inputs_embeds.size(-1)) - if inputs_embeds is not None - else None - ) - - outputs = self.longformer( - flat_input_ids, - position_ids=flat_position_ids, - token_type_ids=flat_token_type_ids, - attention_mask=flat_attention_mask, - global_attention_mask=flat_global_attention_mask, - inputs_embeds=flat_inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, num_choices) - - outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here - - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - outputs = (loss,) + outputs - - return outputs # (loss), reshaped_logits, (hidden_states), (attentions) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 57060f8ff279..8ea884ddc414 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -122,7 +122,9 @@ def _pad_and_transpose_last_two_dims(hidden_states_padded, padding): hidden_states_padded = F.pad( hidden_states_padded, padding ) # padding value is not important because it will be overwritten - hidden_states_padded = hidden_states_padded.transpose(-1, -2) + hidden_states_padded = hidden_states_padded.view( + *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2) + ) return hidden_states_padded @staticmethod