Can't get (global) attention probs using Longformer #5646

@k141303

Description

🐛 Bug

Information

Model I am using: Longformer

Language I am using the model on: Japanese

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Set config.output_attentions=True
  2. Use global attention (sum(global_attention_mask)>0)

The following is a minimal script that reproduces the error.

import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig

if __name__ == '__main__':
    # Request attention probabilities in the model outputs.
    config = AutoConfig.from_pretrained("allenai/longformer-base-4096", output_attentions=True)
    model = AutoModel.from_pretrained("allenai/longformer-base-4096", config=config)
    tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    token_ids = [[
        tokenizer.cls_token_id, 10, 11, 12,
        tokenizer.sep_token_id, 21, 22, 23,
        tokenizer.sep_token_id
    ]]
    # Global attention on the first five tokens (sum(global_attention_mask) > 0).
    global_attention_mask = [[1, 1, 1, 1, 1, 0, 0, 0, 0]]
    logit, *_, attention_probs = model(
        torch.LongTensor(token_ids),
        global_attention_mask=torch.LongTensor(global_attention_mask)
    )
    print(attention_probs[0].size())
$ python3 test.py
Traceback (most recent call last):
  File "test_longformer.py", line 16, in <module>
    global_attention_mask=torch.LongTensor(global_attention_mask)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/transformers/modeling_longformer.py", line 1004, in forward
    output_hidden_states=output_hidden_states,
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/transformers/modeling_longformer.py", line 695, in forward
    layer_outputs = layer_module(hidden_states, attention_mask, output_attentions,)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/transformers/modeling_longformer.py", line 658, in forward
    self_attn_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/transformers/modeling_longformer.py", line 642, in forward
    self_outputs = self.self(hidden_states, attention_mask, output_attentions,)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/transformers/modeling_longformer.py", line 435, in forward
    attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
RuntimeError: shape '[1, 12, 5, 512]' is invalid for input of size 3182592
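
For context, the "input of size 3182592" is consistent with attn_probs still having the local-attention layout (batch_size, padded_seq_len, num_heads, local_cols + max_num_global_attn_indices) at the point of the view, rather than the (batch_size, num_heads, max_num_global_attn_indices, seq_len) layout the view assumes. This is my reading of modeling_longformer.py in v3.0.2, so treat the exact layout as an assumption; the arithmetic can be checked in isolation:

# Sketch: checking the shape arithmetic behind the RuntimeError.
# Assumed layout of attn_probs when global attention is active (v3.0.2):
# (batch_size, padded_seq_len, num_heads, local_cols + max_num_global_attn_indices).
batch_size = 1
num_heads = 12
padded_seq_len = 512                 # 9 input tokens padded up to the attention window
local_cols = 2 * 256 + 1             # sliding-window columns: one-sided window w=256, i.e. 2*w + 1 = 513
max_num_global_attn_indices = 5      # five 1s in global_attention_mask

actual = batch_size * padded_seq_len * num_heads * (local_cols + max_num_global_attn_indices)
target = batch_size * num_heads * max_num_global_attn_indices * padded_seq_len
print(actual)  # 3182592 -- the "input of size" in the traceback
print(target)  # 30720   -- what view(1, 12, 5, 512) would require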

Expected behavior

The model should output the attention probs for each attention head, e.g.:

$ python3 test.py
torch.Size([1, 12, 4096, 5])

It seems to work if I rewrite the target line (modeling_longformer.py, line 435) as follows.

Original:

attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)

Rewritten:

#attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
attn_probs = attn_probs[:, :, :, :max_num_global_attn_indices]
attn_probs = attn_probs.permute(0, 2, 1, 3)
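
To see the intended shape transformation in isolation, here is a minimal standalone sketch with a dummy tensor. The layout of attn_probs, and the assumption that the global-attention columns come first in the last dimension (which is what the [:, :, :, :max_num_global_attn_indices] slice relies on), are my reading of v3.0.2, not confirmed behavior:

import torch

# Dummy attn_probs with the assumed layout
# (batch_size, padded_seq_len, num_heads, global_cols + local_cols),
# where the global columns are assumed to be concatenated first.
batch_size, padded_seq_len, num_heads = 1, 512, 12
max_num_global_attn_indices, local_cols = 5, 513
attn_probs = torch.rand(batch_size, padded_seq_len, num_heads,
                        max_num_global_attn_indices + local_cols)

# Keep only the global-attention columns, then move heads in front:
# (batch, seq, heads, global) -> (batch, heads, seq, global).
global_probs = attn_probs[:, :, :, :max_num_global_attn_indices]
global_probs = global_probs.permute(0, 2, 1, 3)
print(global_probs.size())  # torch.Size([1, 12, 512, 5])

Note that for the 9-token input above the sequence is padded to 512, so the second-to-last dimension is 512 here rather than the 4096 shown in the expected output.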

Environment info

  • transformers version: 3.0.2
  • Platform: Ubuntu 18.04.4 LTS
  • Python version: 3.6.9 :: Anaconda, Inc.
  • PyTorch version (GPU?): 1.5.1 (Yes)
  • Tensorflow version (GPU?):
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes
