diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 9d869e73a1c5..8ea884ddc414 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -25,8 +25,9 @@ from .configuration_longformer import LongformerConfig from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_bert import BertPreTrainedModel -from .modeling_roberta import RobertaLMHead, RobertaModel +from .modeling_bert import BertIntermediate, BertLayerNorm, BertOutput, BertPooler, BertPreTrainedModel, BertSelfOutput +from .modeling_roberta import RobertaEmbeddings, RobertaLMHead +from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer logger = logging.getLogger(__name__) @@ -113,137 +114,171 @@ def __init__(self, config, layer_id): attention_window > 0 ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}" - self.one_sided_attention_window_size = attention_window // 2 + self.one_sided_attn_window_size = attention_window // 2 @staticmethod - def _skew(x, direction): - """Convert diagonals into columns (or columns into diagonals depending on `direction`""" - x_padded = F.pad(x, direction) # padding value is not important because it will be overwritten - x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2)) - return x_padded + def _pad_and_transpose_last_two_dims(hidden_states_padded, padding): + """pads rows and then flips rows and columns""" + hidden_states_padded = F.pad( + hidden_states_padded, padding + ) # padding value is not important because it will be overwritten + hidden_states_padded = hidden_states_padded.view( + *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2) + ) + return hidden_states_padded @staticmethod - def _skew2(x): - """shift every row 1 step to right converting columns into diagonals""" - # X = B x C x M x L - B, C, M, L = x.size() - x = F.pad(x, (0, M + 1)) # B x C x M x (L+M+1). Padding value is not important because it'll be overwritten - x = x.view(B, C, -1) # B x C x ML+MM+M - x = x[:, :, :-M] # B x C x ML+MM - x = x.view(B, C, M, M + L) # B x C, M x L+M - x = x[:, :, :, :-1] - return x + def _pad_by_window_overlap_except_last_row(chunked_hidden_states): + """shift every row 1 step right, converting columns into diagonals""" + total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size() + chunked_hidden_states = F.pad( + chunked_hidden_states, (0, window_overlap + 1) + ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten + chunked_hidden_states = chunked_hidden_states.view( + total_num_heads, num_chunks, -1 + ) # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap+window_overlap + chunked_hidden_states = chunked_hidden_states[ + :, :, :-window_overlap + ] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap + chunked_hidden_states = chunked_hidden_states.view( + total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim + ) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap + chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] + return chunked_hidden_states @staticmethod - def _chunk(x, w): + def _chunk(hidden_states, window_overlap): """convert into overlapping chunkings. 
Chunk size = 2w, overlap size = w""" # non-overlapping chunks of size = 2w - x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2)) + hidden_states = hidden_states.view( + hidden_states.size(0), + hidden_states.size(1) // (window_overlap * 2), + window_overlap * 2, + hidden_states.size(2), + ) - # use `as_strided` to make the chunks overlap with an overlap size = w - chunk_size = list(x.size()) + # use `as_strided` to make the chunks overlap with an overlap size = window_overlap + chunk_size = list(hidden_states.size()) chunk_size[1] = chunk_size[1] * 2 - 1 - chunk_stride = list(x.stride()) + chunk_stride = list(hidden_states.stride()) chunk_stride[1] = chunk_stride[1] // 2 - return x.as_strided(size=chunk_size, stride=chunk_stride) + return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) - def _mask_invalid_locations(self, input_tensor, w) -> torch.Tensor: - affected_seqlen = w - beginning_mask_2d = input_tensor.new_ones(w, w + 1).tril().flip(dims=[0]) + def _mask_invalid_locations(self, input_tensor, affected_seq_len) -> torch.Tensor: + beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0]) beginning_mask = beginning_mask_2d[None, :, None, :] ending_mask = beginning_mask.flip(dims=(1, 3)) - beginning_input = input_tensor[:, :affected_seqlen, :, : w + 1] + beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] beginning_mask = beginning_mask.expand(beginning_input.size()) beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 - ending_input = input_tensor[:, -affected_seqlen:, :, -(w + 1) :] + ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] ending_mask = ending_mask.expand(ending_input.size()) ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 - def _sliding_chunks_matmul_qk(self, q: torch.Tensor, k: torch.Tensor, w: int): - """Matrix multiplicatio of query x key tensors using with a sliding window attention pattern. + def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int): + """Matrix multiplication of query and key tensors using with a sliding window attention pattern. This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) - with an overlap of size w""" - batch_size, seqlen, num_heads, head_dim = q.size() - assert seqlen % (w * 2) == 0, f"Sequence length should be multiple of {w * 2}. Given {seqlen}" - assert q.size() == k.size() + with an overlap of size window_overlap""" + batch_size, seq_len, num_heads, head_dim = query.size() + assert ( + seq_len % (window_overlap * 2) == 0 + ), f"Sequence length should be multiple of {window_overlap * 2}. 
Given {seq_len}" + assert query.size() == key.size() - chunks_count = seqlen // w - 1 + chunks_count = seq_len // window_overlap - 1 - # group batch_size and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2 - q = q.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim) - k = k.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim) + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 + query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) + key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) - chunk_q = self._chunk(q, w) - chunk_k = self._chunk(k, w) + chunked_query = self._chunk(query, window_overlap) + chunked_key = self._chunk(key, window_overlap) # matrix multipication - # bcxd: batch_size * num_heads x chunks x 2w x head_dim - # bcyd: batch_size * num_heads x chunks x 2w x head_dim - # bcxy: batch_size * num_heads x chunks x 2w x 2w - chunk_attn = torch.einsum("bcxd,bcyd->bcxy", (chunk_q, chunk_k)) # multiply + # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcxy: batch_size * num_heads x chunks x 2window_overlap x window_overlap + chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunked_query, chunked_key)) # multiply # convert diagonals into columns - diagonal_chunk_attn = self._skew(chunk_attn, direction=(0, 0, 0, 1)) + diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims( + chunked_attention_scores, padding=(0, 0, 0, 1) + ) - # allocate space for the overall attention matrix where the chunks are compined. The last dimension - # has (w * 2 + 1) columns. The first (w) columns are the w lower triangles (attention from a word to - # w previous words). The following column is attention score from each word to itself, then - # followed by w columns for the upper triangle. + # allocate space for the overall attention matrix where the chunks are combined. The last dimension + # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to + # window_overlap previous words). The following column is attention score from each word to itself, then + # followed by window_overlap columns for the upper triangle. 
- diagonal_attn = diagonal_chunk_attn.new_empty((batch_size * num_heads, chunks_count + 1, w, w * 2 + 1)) + diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty( + (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1) + ) - # copy parts from diagonal_chunk_attn into the compined matrix of attentions + # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions # - copying the main diagonal and the upper triangle - diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, : w + 1] - diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, : w + 1] + diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[ + :, :, :window_overlap, : window_overlap + 1 + ] + diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[ + :, -1, window_overlap:, : window_overlap + 1 + ] # - copying the lower triangle - diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, -(w + 1) : -1, w + 1 :] - diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, : w - 1, 1 - w :] + diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[ + :, :, -(window_overlap + 1) : -1, window_overlap + 1 : + ] + diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[ + :, 0, : window_overlap - 1, 1 - window_overlap : + ] # separate batch_size and num_heads dimensions again - diagonal_attn = diagonal_attn.view(batch_size, num_heads, seqlen, 2 * w + 1).transpose(2, 1) - - self._mask_invalid_locations(diagonal_attn, w) - return diagonal_attn - - def _sliding_chunks_matmul_pv(self, prob: torch.Tensor, v: torch.Tensor, w: int): - """Same as _sliding_chunks_matmul_qk but for prob and value tensors. It is expecting the same output - format from _sliding_chunks_matmul_qk""" - batch_size, seqlen, num_heads, head_dim = v.size() - assert seqlen % (w * 2) == 0 - assert prob.size()[:3] == v.size()[:3] - assert prob.size(3) == 2 * w + 1 - chunks_count = seqlen // w - 1 - # group batch_size and num_heads dimensions into one, then chunk seqlen into chunks of size 2w - chunk_prob = prob.transpose(1, 2).reshape(batch_size * num_heads, seqlen // w, w, 2 * w + 1) + diagonal_attention_scores = diagonal_attention_scores.view( + batch_size, num_heads, seq_len, 2 * window_overlap + 1 + ).transpose(2, 1) - # group batch_size and num_heads dimensions into one - v = v.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim) + self._mask_invalid_locations(diagonal_attention_scores, window_overlap) + return diagonal_attention_scores - # pad seqlen with w at the beginning of the sequence and another w at the end - padded_v = F.pad(v, (0, 0, w, w), value=-1) + def _sliding_chunks_matmul_attn_probs_value( + self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int + ): + """Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. 
+ Returned tensor will be of the same shape as `attn_probs`""" + batch_size, seq_len, num_heads, head_dim = value.size() + assert seq_len % (window_overlap * 2) == 0 + assert attn_probs.size()[:3] == value.size()[:3] + assert attn_probs.size(3) == 2 * window_overlap + 1 + chunks_count = seq_len // window_overlap - 1 + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap + chunked_attn_probs = attn_probs.transpose(1, 2).reshape( + batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1 + ) - # chunk padded_v into chunks of size 3w and an overlap of size w - chunk_v_size = (batch_size * num_heads, chunks_count + 1, 3 * w, head_dim) - chunk_v_stride = padded_v.stride() - chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2] - chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride) + # group batch_size and num_heads dimensions into one + value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) + + # pad seq_len with w at the beginning of the sequence and another window overlap at the end + padded_value = F.pad(value, (0, 0, window_overlap, window_overlap), value=-1) + + # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap + chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim) + chunked_value_stride = padded_value.stride() + chunked_value_stride = ( + chunked_value_stride[0], + window_overlap * chunked_value_stride[1], + chunked_value_stride[1], + chunked_value_stride[2], + ) + chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride) - skewed_prob = self._skew2(chunk_prob) + chunked_attn_probs = self._pad_by_window_overlap_except_last_row(chunked_attn_probs) - context = torch.einsum("bcwd,bcdh->bcwh", (skewed_prob, chunk_v)) - return context.view(batch_size, num_heads, seqlen, head_dim).transpose(1, 2) + context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value)) + return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2) def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - output_attentions=False, + self, hidden_states, attention_mask=None, output_attentions=False, ): """ LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`. 
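The chunked matrix multiplications above are the heart of the sliding-window attention, so here is a minimal, self-contained sketch (not part of the patch) of the `as_strided` overlapping-chunk trick and the banded query-key scores it produces. The standalone `chunk` helper below is hypothetical and only mirrors `_chunk` / `_sliding_chunks_query_key_matmul` on a toy tensor:

```python
import torch


def chunk(hidden_states: torch.Tensor, window_overlap: int) -> torch.Tensor:
    """Overlapping chunks of size 2 * window_overlap with overlap window_overlap,
    mirroring LongformerSelfAttention._chunk (illustrative only)."""
    batch, seq_len, head_dim = hidden_states.size()
    # first cut the sequence into non-overlapping chunks of size 2 * window_overlap
    hidden_states = hidden_states.view(batch, seq_len // (window_overlap * 2), window_overlap * 2, head_dim)
    # then double the number of chunks (minus one) and halve the stride along the
    # chunk dimension, so consecutive chunks share window_overlap positions
    chunk_size = list(hidden_states.size())
    chunk_size[1] = chunk_size[1] * 2 - 1
    chunk_stride = list(hidden_states.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)


window_overlap = 2                    # one-sided window size
query = torch.randn(1, 8, 4)          # (batch * heads, seq_len, head_dim), seq_len % (2 * window_overlap) == 0
key = torch.randn(1, 8, 4)

chunked_query = chunk(query, window_overlap)   # (1, 3, 4, 4): 3 overlapping chunks of 4 positions each
chunked_key = chunk(key, window_overlap)
# scores of every position against the 2 * window_overlap positions in its own chunk
chunked_scores = torch.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key)
print(chunked_query.shape, chunked_scores.shape)  # torch.Size([1, 3, 4, 4]) torch.Size([1, 3, 4, 4])
```

Because consecutive chunks overlap by `window_overlap` positions, every token ends up with scores against its `window_overlap` neighbours on each side; the diagonal copy-and-mask steps in `_sliding_chunks_query_key_matmul` then rearrange these per-chunk scores into the `(batch_size, seq_len, num_heads, 2 * window_overlap + 1)` banded layout.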
@@ -254,187 +289,449 @@ def forward( 0: local attention +ve: global attention - `encoder_hidden_states` and `encoder_attention_mask` are not supported and should be None """ - # TODO: add support for `encoder_hidden_states` and `encoder_attention_mask` - assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None" - assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and shiould be None" - if attention_mask is not None: - attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) - key_padding_mask = attention_mask < 0 - extra_attention_mask = attention_mask > 0 - remove_from_windowed_attention_mask = attention_mask != 0 - - num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1) - max_num_extra_indices_per_batch = num_extra_indices_per_batch.max() - if max_num_extra_indices_per_batch <= 0: - extra_attention_mask = None - else: - # To support the case of variable number of global attention in the rows of a batch, - # we use the following three selection masks to select global attention embeddings - # in a 3d tensor and pad it to `max_num_extra_indices_per_batch` - # 1) selecting embeddings that correspond to global attention - extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True) - zero_to_max_range = torch.arange( - 0, max_num_extra_indices_per_batch, device=num_extra_indices_per_batch.device - ) - # mask indicating which values are actually going to be padding - selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1) - # 2) location of the non-padding values in the selected global attention - selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True) - # 3) location of the padding values in the selected global attention - selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True) - else: - remove_from_windowed_attention_mask = None - extra_attention_mask = None - key_padding_mask = None + attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) + + # is index masked or global attention + is_index_masked = attention_mask < 0 + is_index_global_attn = attention_mask > 0 + is_global_attn = any(is_index_global_attn.flatten()) hidden_states = hidden_states.transpose(0, 1) - seqlen, batch_size, embed_dim = hidden_states.size() - assert embed_dim == self.embed_dim - q = self.query(hidden_states) - k = self.key(hidden_states) - v = self.value(hidden_states) - q /= math.sqrt(self.head_dim) - - q = q.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - k = k.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - # attn_weights = (batch_size, seqlen, num_heads, window*2+1) - attn_weights = self._sliding_chunks_matmul_qk(q, k, self.one_sided_attention_window_size) - if remove_from_windowed_attention_mask is not None: - # This implementation is fast and takes very little memory because num_heads x hidden_size = 1 - # from (batch_size x seqlen) to (batch_size x seqlen x num_heads x hidden_size) - remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze( - dim=-1 - ) - # cast to fp32/fp16 then replace 1's with -inf - float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill( - remove_from_windowed_attention_mask, -10000.0 - ) - ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones - # diagonal mask with zeros everywhere and -inf inplace of padding - d_mask = self._sliding_chunks_matmul_qk(ones, 
float_mask, self.one_sided_attention_window_size) - attn_weights += d_mask - assert list(attn_weights.size()) == [ + + # project hidden states + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = self.value(hidden_states) + + seq_len, batch_size, embed_dim = hidden_states.size() + assert ( + embed_dim == self.embed_dim + ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}" + + # normalize query + query_vectors /= math.sqrt(self.head_dim) + + query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + + # attn_probs = (batch_size, seq_len, num_heads, window*2+1) + attn_scores = self._sliding_chunks_query_key_matmul( + query_vectors, key_vectors, self.one_sided_attn_window_size + ) + + # values to pad for attention probs + remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1) + + # cast to fp32/fp16 then replace 1's with -inf + float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill( + remove_from_windowed_attention_mask, -10000.0 + ) + # diagonal mask with zeros everywhere and -inf inplace of padding + diagonal_mask = self._sliding_chunks_query_key_matmul( + float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size + ) + + # pad local attention probs + attn_scores += diagonal_mask + + assert list(attn_scores.size()) == [ batch_size, - seqlen, + seq_len, self.num_heads, - self.one_sided_attention_window_size * 2 + 1, - ] + self.one_sided_attn_window_size * 2 + 1, + ], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}" + + # compute local attention probs from global attention keys and contact over window dim + if is_global_attn: + # compute global attn indices required through out forward fn + ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) = self._get_global_attn_indices(is_index_global_attn) + # calculate global attn probs from global key + global_key_attn_scores = self._concat_with_global_key_attn_probs( + query_vectors=query_vectors, + key_vectors=key_vectors, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + ) + # concat to attn_probs + # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) + attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1) - # the extra attention - if extra_attention_mask is not None: - selected_k = k.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) - selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros] - # (batch_size, seqlen, num_heads, max_num_extra_indices_per_batch) - selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k)) - selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0 - # concat to attn_weights - # (batch_size, seqlen, num_heads, extra attention count + 2*window+1) - attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) - - 
attn_weights_fp32 = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability - attn_weights = attn_weights_fp32.type_as(attn_weights) - - if key_padding_mask is not None: - # softmax sometimes inserts NaN if all positions are masked, replace them with 0 - attn_weights = torch.masked_fill(attn_weights, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0) - - attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) - v = v.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - attn = None - if extra_attention_mask is not None: - selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch) - selected_v = v.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) - selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros] - # use `matmul` because `einsum` crashes sometimes with fp16 - # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) - attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2)).transpose(1, 2) - attn_probs = attn_probs.narrow( - -1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch - ).contiguous() - if attn is None: - attn = self._sliding_chunks_matmul_pv(attn_probs, v, self.one_sided_attention_window_size) - else: - attn += self._sliding_chunks_matmul_pv(attn_probs, v, self.one_sided_attention_window_size) + # free memory + del global_key_attn_scores - assert attn.size() == (batch_size, seqlen, self.num_heads, self.head_dim), "Unexpected size" - attn = attn.transpose(0, 1).reshape(seqlen, batch_size, embed_dim).contiguous() + attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + attn_probs = attn_probs_fp32.type_as(attn_scores) - # For this case, we'll just recompute the attention for these indices - # and overwrite the attn tensor. 
- # TODO: remove the redundant computation - if extra_attention_mask is not None: - selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, batch_size, embed_dim) - selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[ - extra_attention_mask_nonzeros[::-1] - ] + # free memory + del attn_probs_fp32 - q = self.query_global(selected_hidden_states) - k = self.key_global(hidden_states) - v = self.value_global(hidden_states) - q /= math.sqrt(self.head_dim) - - q = ( - q.contiguous() - .view(max_num_extra_indices_per_batch, batch_size * self.num_heads, self.head_dim) - .transpose(0, 1) - ) # (batch_size * self.num_heads, max_num_extra_indices_per_batch, head_dim) - k = ( - k.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) - ) # batch_size * self.num_heads, seqlen, head_dim) - v = ( - v.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) - ) # batch_size * self.num_heads, seqlen, head_dim) - attn_weights = torch.bmm(q, k.transpose(1, 2)) - assert list(attn_weights.size()) == [batch_size * self.num_heads, max_num_extra_indices_per_batch, seqlen] - - attn_weights = attn_weights.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seqlen) - attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0 - if key_padding_mask is not None: - attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -10000.0,) - attn_weights = attn_weights.view(batch_size * self.num_heads, max_num_extra_indices_per_batch, seqlen) - attn_weights_float = F.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ) # use fp32 for numerical stability - attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) - selected_attn = torch.bmm(attn_probs, v) - assert list(selected_attn.size()) == [ - batch_size * self.num_heads, - max_num_extra_indices_per_batch, - self.head_dim, - ] + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + attn_probs = torch.masked_fill(attn_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0) + + # apply dropout + attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training) + + value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + + # compute local attention output with global attention value and add + if is_global_attn: + # compute sum of global and local attn + attn_output = self._compute_attn_output_with_global_indices( + value_vectors=value_vectors, + attn_probs=attn_probs, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + ) + else: + # compute local attn only + attn_output = self._sliding_chunks_matmul_attn_probs_value( + attn_probs, value_vectors, self.one_sided_attn_window_size + ) + + assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" + attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() - selected_attn_4d = selected_attn.view( - batch_size, self.num_heads, max_num_extra_indices_per_batch, self.head_dim + # compute value for global attention and overwrite to attention output + # TODO: remove the redundant computation + if is_global_attn: + global_attn_output = self._compute_global_attn_output_from_hidden( + 
hidden_states=hidden_states, + max_num_global_attn_indices=max_num_global_attn_indices, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + is_index_masked=is_index_masked, ) - nonzero_selected_attn = selected_attn_4d[ - selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1] + + # get only non zero global attn output + nonzero_global_attn_output = global_attn_output[ + is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1] ] - attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view( - len(selection_padding_mask_nonzeros[0]), -1 + # overwrite values with global attention + attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view( + len(is_local_index_global_attn_nonzero[0]), -1 ) - context_layer = attn.transpose(0, 1) + attn_output = attn_output.transpose(0, 1) + if output_attentions: - if extra_attention_mask is not None: + if is_global_attn: # With global attention, return global attention probabilities only # batch_size x num_heads x max_num_global_attention_tokens x sequence_length # which is the attention weights from tokens with global attention to all tokens # It doesn't not return local attention # In case of variable number of global attantion in the rows of a batch, - # attn_weights are padded with -10000.0 attention scores - attn_weights = attn_weights.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seqlen) + # attn_probs are padded with -10000.0 attention scores + attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) else: # without global attention, return local attention probabilities # batch_size x num_heads x sequence_length x window_size # which is the attention weights of every token attending to its neighbours - attn_weights = attn_weights.permute(0, 2, 1, 3) - outputs = (context_layer, attn_weights) if output_attentions else (context_layer,) + attn_probs = attn_probs.permute(0, 2, 1, 3) + + outputs = (attn_output, attn_probs) if output_attentions else (attn_output,) + return outputs + + @staticmethod + def _get_global_attn_indices(is_index_global_attn): + """ compute global attn indices required throughout forward pass """ + # helper variable + num_global_attn_indices = is_index_global_attn.long().sum(dim=1) + + # max number of global attn indices in batch + max_num_global_attn_indices = num_global_attn_indices.max() + + # indices of global attn + is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True) + + # helper variable + is_local_index_global_attn = torch.arange( + max_num_global_attn_indices, device=is_index_global_attn.device + ) < num_global_attn_indices.unsqueeze(dim=-1) + + # location of the non-padding values within global attention indices + is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True) + + # location of the padding values within global attention indices + is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True) + return ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) + + def _concat_with_global_key_attn_probs( + self, + key_vectors, + query_vectors, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + 
is_local_index_no_global_attn_nonzero, + ): + batch_size = key_vectors.shape[0] + + # create only global key vectors + key_vectors_only_global = key_vectors.new_zeros( + batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim + ) + key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero] + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) + attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) + attn_probs_from_global_key[ + is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1] + ] = -10000.0 + return attn_probs_from_global_key + + def _compute_attn_output_with_global_indices( + self, + value_vectors, + attn_probs, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + ): + batch_size = attn_probs.shape[0] + + # cut local attn probs to global only + attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices) + # get value vectors for global only + value_vectors_only_global = value_vectors.new_zeros( + batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim + ) + value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero] + + # use `matmul` because `einsum` crashes sometimes with fp16 + # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) + # compute attn output only global + attn_output_only_global = torch.matmul( + attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2) + ).transpose(1, 2) + + # reshape attn probs + attn_probs_without_global = attn_probs.narrow( + -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices + ).contiguous() + + # compute attn output with global + attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( + attn_probs_without_global, value_vectors, self.one_sided_attn_window_size + ) + return attn_output_only_global + attn_output_without_global + + def _compute_global_attn_output_from_hidden( + self, + hidden_states, + max_num_global_attn_indices, + is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + is_index_masked, + ): + seq_len, batch_size = hidden_states.shape[:2] + + # prepare global hidden states + global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim) + global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[ + is_index_global_attn_nonzero[::-1] + ] + + # global key, query, value + global_query_vectors_only_global = self.query_global(global_attn_hidden_states) + global_key_vectors = self.key_global(hidden_states) + global_value_vectors = self.value_global(hidden_states) + + # normalize + global_query_vectors_only_global /= math.sqrt(self.head_dim) + + # reshape + global_query_vectors_only_global = ( + global_query_vectors_only_global.contiguous() + .view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim) + global_key_vectors = ( + global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + ) # batch_size * self.num_heads, seq_len, head_dim) + global_value_vectors = ( + global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + ) # 
batch_size * self.num_heads, seq_len, head_dim) + + # compute attn scores + global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2)) + + assert list(global_attn_scores.size()) == [ + batch_size * self.num_heads, + max_num_global_attn_indices, + seq_len, + ], f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attn_scores.size()}." + + global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + + global_attn_scores[ + is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], : + ] = -10000.0 + + global_attn_scores = global_attn_scores.masked_fill(is_index_masked.unsqueeze(1).unsqueeze(2), -10000.0,) + + global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + + # compute global attn probs + global_attn_probs_float = F.softmax( + global_attn_scores, dim=-1, dtype=torch.float32 + ) # use fp32 for numerical stability + + global_attn_probs = F.dropout( + global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training + ) + + # global attn output + global_attn_output = torch.bmm(global_attn_probs, global_value_vectors) + + assert list(global_attn_output.size()) == [ + batch_size * self.num_heads, + max_num_global_attn_indices, + self.head_dim, + ], f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}." + + global_attn_output = global_attn_output.view( + batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim + ) + return global_attn_output + + +class LongformerAttention(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.self = LongformerSelfAttention(config, layer_id) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, hidden_states, attention_mask=None, output_attentions=False, + ): + self_outputs = self.self(hidden_states, attention_mask, output_attentions,) + attn_output = self.output(self_outputs[0], hidden_states) + outputs = (attn_output,) + self_outputs[1:] # add attentions if we output them return outputs +class LongformerLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.attention = LongformerAttention(config, layer_id) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, hidden_states, attention_mask=None, output_attentions=False, + ): + self_attn_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,) + 
attn_output = self_attn_outputs[0] + outputs = self_attn_outputs[1:] # add self attentions if we output attention weights + + intermediate_output = self.intermediate(attn_output) + layer_output = self.output(intermediate_output, attn_output) + outputs = (layer_output,) + outputs + return outputs + + +class LongformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)]) + + def forward( + self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False, + ): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), hidden_states, attention_mask, + ) + else: + layer_outputs = layer_module(hidden_states, attention_mask, output_attentions,) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + +class LongformerPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained + models. + """ + + config_class = LongformerConfig + base_model_prefix = "longformer" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + LONGFORMER_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `__ sub-class. 
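Since several of the new helpers revolve around the global-attention index bookkeeping, a small standalone illustration of what `_get_global_attn_indices` computes may help. This is a sketch, not part of the patch; the mask values follow the convention documented in `LongformerSelfAttention.forward` (negative = no attention, 0 = local, positive = global), and rows with fewer global tokens are padded up to `max_num_global_attn_indices`:

```python
import torch

# toy merged mask for a batch of 2 sequences of length 5
attention_mask = torch.tensor(
    [
        [2.0, 0.0, 0.0, 0.0, -1.0],   # one global token, one padding token
        [2.0, 2.0, 0.0, 0.0, 0.0],    # two global tokens, no padding
    ]
)

is_index_masked = attention_mask < 0
is_index_global_attn = attention_mask > 0

num_global_attn_indices = is_index_global_attn.long().sum(dim=1)   # tensor([1, 2])
max_num_global_attn_indices = num_global_attn_indices.max()        # tensor(2)

# (batch index, sequence index) of every global token
is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)

# rows have different numbers of global tokens, so the per-row global slots are
# padded to max_num_global_attn_indices; these masks split real slots from padding
is_local_index_global_attn = torch.arange(max_num_global_attn_indices) < num_global_attn_indices.unsqueeze(-1)
is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True)
is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True)

print(max_num_global_attn_indices)            # tensor(2)
print(is_index_global_attn_nonzero)           # (tensor([0, 1, 1]), tensor([0, 0, 1]))
print(is_local_index_no_global_attn_nonzero)  # (tensor([0]), tensor([1]))
```

These index tuples are what `_concat_with_global_key_attn_probs`, `_compute_attn_output_with_global_indices`, and `_compute_global_attn_output_from_hidden` use to scatter and gather the global tokens without looping over the batch.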
@@ -498,7 +795,7 @@ def forward( "The bare Longformer Model outputting raw hidden-states without any specific head on top.", LONGFORMER_START_DOCSTRING, ) -class LongformerModel(RobertaModel): +class LongformerModel(LongformerPreTrainedModel): """ This class overrides :class:`~transformers.RobertaModel` to provide the ability to process long sequences following the selfattention approach described in `Longformer: the Long-Document Transformer @@ -519,6 +816,7 @@ class LongformerModel(RobertaModel): def __init__(self, config): super().__init__(config) + self.config = config if isinstance(config.attention_window, int): assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" @@ -530,12 +828,26 @@ def __init__(self, config): f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" ) - for i, layer in enumerate(self.encoder.layer): - # replace the `modeling_bert.BertSelfAttention` object with `LongformerSelfAttention` - layer.attention.self = LongformerSelfAttention(config, layer_id=i) + self.embeddings = RobertaEmbeddings(config) + self.encoder = LongformerEncoder(config) + self.pooler = BertPooler(config) self.init_weights() + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + def _pad_to_window_size( self, input_ids: torch.Tensor, @@ -543,30 +855,29 @@ def _pad_to_window_size( token_type_ids: torch.Tensor, position_ids: torch.Tensor, inputs_embeds: torch.Tensor, - attention_window: int, pad_token_id: int, ): """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" + # padding + attention_window = ( + self.config.attention_window + if isinstance(self.config.attention_window, int) + else max(self.config.attention_window) + ) assert attention_window % 2 == 0, f"`attention_window` should be an even value. 
Given {attention_window}" input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape - batch_size, seqlen = input_shape[:2] + batch_size, seq_len = input_shape[:2] - padding_len = (attention_window - seqlen % attention_window) % attention_window + padding_len = (attention_window - seq_len % attention_window) % attention_window if padding_len > 0: logger.info( "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format( - seqlen, seqlen + padding_len, attention_window + seq_len, seq_len + padding_len, attention_window ) ) if input_ids is not None: input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id) - if attention_mask is not None: - attention_mask = F.pad( - attention_mask, (0, padding_len), value=False - ) # no attention on the padding tokens - if token_type_ids is not None: - token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0 if position_ids is not None: # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id) @@ -577,8 +888,23 @@ def _pad_to_window_size( inputs_embeds_padding = self.embeddings(input_ids_padding) inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2) + attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens + token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0 + return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds + def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor): + # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) + # (global_attention_mask + 1) => 1 for local attention, 2 for global attention + # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention + if attention_mask is not None: + attention_mask = attention_mask * (global_attention_mask + 1) + else: + # simply use `global_attention_mask` as `attention_mask` + # if no `attention_mask` is given + attention_mask = global_attention_mask + 1 + return attention_mask + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, @@ -634,24 +960,25 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - # padding - attention_window = ( - self.config.attention_window - if isinstance(self.config.attention_window, int) - else max(self.config.attention_window) - ) + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) # merge `global_attention_mask` and `attention_mask` if global_attention_mask is not None: - # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) - # 
(global_attention_mask + 1) => 1 for local attention, 2 for global attention - # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention - if attention_mask is not None: - attention_mask = attention_mask * (global_attention_mask + 1) - else: - # simply use `global_attention_mask` as `attention_mask` - # if no `attention_mask` is given - attention_mask = global_attention_mask + 1 + attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size( input_ids=input_ids, @@ -659,23 +986,29 @@ def forward( token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, - attention_window=attention_window, pad_token_id=self.config.pad_token_id, ) - # embed - output = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=None, - inputs_embeds=inputs_embeds, - encoder_hidden_states=None, - encoder_attention_mask=None, + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) + + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + outputs = (sequence_output, pooled_output,) + encoder_outputs[ + 1: + ] # add hidden_states and attentions if they are here # undo padding if padding_len > 0: @@ -684,13 +1017,13 @@ def forward( # `pooled_output`: independent of the sequence length # `hidden_states`: mainly used for debugging and analysis, so keep the padding # `attentions`: mainly used for debugging and analysis, so keep the padding - output = output[0][:, :-padding_len], *output[1:] + outputs = outputs[0][:, :-padding_len], *outputs[1:] - return output + return outputs @add_start_docstrings("""Longformer Model with a `language modeling` head on top. 
""", LONGFORMER_START_DOCSTRING) -class LongformerForMaskedLM(BertPreTrainedModel): +class LongformerForMaskedLM(LongformerPreTrainedModel): config_class = LongformerConfig base_model_prefix = "longformer" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4df1f75011ae..5e6705eaccaf 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -812,7 +812,7 @@ def test_multigpu_data_parallel_forward(self): # Wrap model in nn.DataParallel model = torch.nn.DataParallel(model) with torch.no_grad(): - _ = model(**inputs_dict) + _ = model(**self._prepare_for_class(inputs_dict, model_class)) global_rng = random.Random() diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index 7579ee38ba0d..fed22060501c 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -115,6 +115,18 @@ def prepare_config_and_inputs(self): def check_loss_output(self, result): self.parent.assertListEqual(list(result["loss"].size()), []) + def create_and_check_attention_mask_determinism( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = LongformerModel(config=config) + model.to(torch_device) + model.eval() + + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + output_with_mask = model(input_ids, attention_mask=attention_mask)[0] + output_without_mask = model(input_ids)[0] + self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4)) + def create_and_check_longformer_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -134,6 +146,36 @@ def create_and_check_longformer_model( ) self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) + def create_and_check_longformer_model_with_global_attention_mask( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = LongformerModel(config=config) + model.to(torch_device) + model.eval() + global_attention_mask = input_mask.clone() + global_attention_mask[:, input_mask.shape[-1] // 2] = 0 + global_attention_mask = global_attention_mask.to(torch_device) + + sequence_output, pooled_output = model( + input_ids, + attention_mask=input_mask, + global_attention_mask=global_attention_mask, + token_type_ids=token_type_ids, + ) + sequence_output, pooled_output = model( + input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask + ) + sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask) + + result = { + "sequence_output": sequence_output, + "pooled_output": pooled_output, + } + self.parent.assertListEqual( + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] + ) + self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) + def create_and_check_longformer_for_masked_lm( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -243,7 +285,13 @@ def prepare_config_and_inputs_for_common(self): token_labels, choice_labels, ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + global_attention_mask = torch.zeros_like(input_ids) + inputs_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + 
"attention_mask": input_mask, + "global_attention_mask": global_attention_mask, + } return config, inputs_dict def prepare_config_and_inputs_for_question_answering(self): @@ -277,11 +325,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase): ( LongformerModel, LongformerForMaskedLM, - # TODO: make tests pass for those models - # LongformerForSequenceClassification, - # LongformerForQuestionAnswering, - # LongformerForTokenClassification, - # LongformerForMultipleChoice, + LongformerForSequenceClassification, + LongformerForQuestionAnswering, + LongformerForTokenClassification, + LongformerForMultipleChoice, ) if is_torch_available() else () @@ -298,6 +345,14 @@ def test_longformer_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_longformer_model(*config_and_inputs) + def test_longformer_model_attention_mask_determinism(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs) + + def test_longformer_model_global_attention_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs) + def test_longformer_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs) @@ -325,15 +380,31 @@ def test_inference_no_head(self): model = LongformerModel.from_pretrained("allenai/longformer-base-4096") model.to(torch_device) + # 'Hello world!' + input_ids = torch.tensor([[0, 20920, 232, 328, 1437, 2]], dtype=torch.long, device=torch_device) + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + output = model(input_ids, attention_mask=attention_mask)[0] + output_without_mask = model(input_ids)[0] + + expected_output_slice = torch.tensor([0.0549, 0.1087, -0.1119, -0.0368, 0.0250], device=torch_device) + self.assertTrue(torch.allclose(output[0, 0, -5:], expected_output_slice, atol=1e-4)) + self.assertTrue(torch.allclose(output_without_mask[0, 0, -5:], expected_output_slice, atol=1e-4)) + + @slow + def test_inference_no_head_long(self): + model = LongformerModel.from_pretrained("allenai/longformer-base-4096") + model.to(torch_device) + # 'Hello world! 
' repeated 1000 times input_ids = torch.tensor( [[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device ) # long input attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) - attention_mask[:, [1, 4, 21]] = 2 # Set global attention on a few random positions + global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device) + global_attention_mask[:, [1, 4, 21]] = 1 # Set global attention on a few random positions - output = model(input_ids, attention_mask=attention_mask)[0] + output = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)[0] expected_output_sum = torch.tensor(74585.8594, device=torch_device) expected_output_mean = torch.tensor(0.0243, device=torch_device) @@ -341,7 +412,7 @@ def test_inference_no_head(self): self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4)) @slow - def test_inference_masked_lm(self): + def test_inference_masked_lm_long(self): model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096") model.to(torch_device) @@ -352,9 +423,9 @@ def test_inference_masked_lm(self): loss, prediction_scores = model(input_ids, labels=input_ids) - expected_loss = torch.tensor(0.0620, device=torch_device) - expected_prediction_scores_sum = torch.tensor(-6.1599e08, device=torch_device) - expected_prediction_scores_mean = torch.tensor(-3.0622, device=torch_device) + expected_loss = torch.tensor(0.0074, device=torch_device) + expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device) + expected_prediction_scores_mean = torch.tensor(-3.0348, device=torch_device) input_ids = input_ids.to(torch_device) self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4))
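As the updated tests show, global attention is now requested through a separate 0/1 `global_attention_mask` rather than by writing 2s into `attention_mask`; internally `_merge_to_attention_mask` combines the two via `attention_mask * (global_attention_mask + 1)`. A usage sketch, assuming the `allenai/longformer-base-4096` checkpoint referenced in the tests is available locally or can be downloaded:

```python
import torch
from transformers import LongformerModel

model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
model.eval()

# 'Hello world! ' repeated 1000 times, as in test_inference_no_head_long
input_ids = torch.tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long)          # 1 = attend, 0 = padding
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long)
global_attention_mask[:, [1, 4, 21]] = 1                                # mark a few tokens as global

with torch.no_grad():
    sequence_output, pooled_output = model(
        input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask
    )

# the mask LongformerSelfAttention ends up working with: 0 = none, 1 = local, 2 = global
merged_mask = attention_mask * (global_attention_mask + 1)
print(merged_mask[0, :6])     # tensor([1, 2, 1, 1, 2, 1])
print(sequence_output.shape)  # torch.Size([1, 4003, 768]); padding to the attention window is undone
```

Inputs whose length is not a multiple of `config.attention_window` are padded internally by `_pad_to_window_size` and the padding is stripped from `sequence_output` before it is returned, so callers never see the padded length.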