117 changes: 101 additions & 16 deletions src/transformers/models/fsmt/modeling_fsmt.py
@@ -240,6 +240,17 @@
also be used by default. If you want to change padding behavior, you should read
:func:`modeling_fsmt._prepare_fsmt_decoder_inputs` and modify it. See diagram 1 in the paper for more info on
the default strategy.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
@@ -282,7 +293,11 @@ def triu_onnx(x, diagonal=0):


def _prepare_fsmt_decoder_inputs(
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
config,
input_ids,
decoder_input_ids=None,
decoder_padding_mask=None,
causal_mask_dtype=torch.float32,
):
"""
Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if none are provided.
@@ -377,21 +392,27 @@ def __init__(self, config: FSMTConfig):
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)

def forward(self, x, encoder_padding_mask, output_attentions=False):
def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=False):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
x (:obj:`torch.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (:obj:`torch.ByteTensor`): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
A value of ``1`` means the corresponding position is excluded (masked out) from
attention, and ``0`` means it is included.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`.

Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
x, attn_weights = self.self_attn(
query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
query=x,
key=x,
key_padding_mask=encoder_padding_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
@@ -432,21 +453,32 @@ def __init__(self, config: FSMTConfig, embed_tokens):
) # type: List[EncoderLayer]

def forward(
self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True
self,
input_ids,
attention_mask=None,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
"""
Args:
input_ids (LongTensor): tokens in the source language of shape
input_ids (:obj:`torch.LongTensor`): tokens in the source language of shape
`(batch, src_len)`
attention_mask (torch.LongTensor): indicating which indices are padding tokens
attention_mask (:obj:`torch.LongTensor`): mask indicating which indices are padding tokens
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

Returns:
BaseModelOutput or Tuple comprised of:

- **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
- **encoder_states** (tuple(torch.FloatTensor)): all intermediate hidden states of shape `(src_len,
batch, embed_dim)`. Only populated if *output_hidden_states:* is True.
- **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
- **x** (:obj:`torch.Tensor`): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
- **encoder_states** (:obj:`Tuple(torch.FloatTensor)`): all intermediate hidden states of shape
  `(src_len, batch, embed_dim)`. Only populated if *output_hidden_states* is True.
- **all_attentions** (:obj:`Tuple(torch.FloatTensor)`): attention weights for each layer.
  During training, these might not be of length n_layers because of layer dropout.
"""
# check attention mask and invert
@@ -463,7 +495,12 @@ def forward(

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for encoder_layer in self.layers:
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
x = x.transpose(0, 1) # T x B x C -> B x T x C
encoder_states += (x,)
@@ -473,7 +510,12 @@ def forward(
if self.training and (dropout_probability < self.layerdrop): # skip the layer
attn = None
else:
x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)
x, attn = encoder_layer(
x,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)

if output_attentions:
all_attentions = all_attentions + (attn,)
@@ -522,6 +564,8 @@ def forward(
encoder_attn_mask=None,
layer_state=None,
causal_mask=None,
layer_head_mask=None,
encoder_layer_head_mask=None,
decoder_padding_mask=None,
output_attentions=False,
):
@@ -537,6 +581,7 @@ def forward(
layer_state=layer_state, # adds keys to layer state
key_padding_mask=decoder_padding_mask,
attn_mask=causal_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -551,6 +596,7 @@ def forward(
key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
layer_head_mask=encoder_layer_head_mask,
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -611,6 +657,8 @@ def forward(
encoder_padding_mask,
decoder_padding_mask,
decoder_causal_mask,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=False,
output_attentions=False,
@@ -622,12 +670,24 @@ def forward(
EMNLP 2019).

Args:
input_ids (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch, tgt_len)`):
previous decoder outputs for teacher forcing
encoder_hidden_states: output from the encoder, used for
encoder-side attention
encoder_padding_mask: for ignoring pad tokens
past_key_values (dict or None): dictionary used for storing state during generation
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

Returns:
BaseModelOutputWithPast or tuple:
@@ -662,6 +722,12 @@ def forward(
all_self_attns = () if output_attentions else None
all_cross_attns = () if output_attentions else None
next_decoder_cache = []

# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
@@ -681,6 +747,8 @@ def forward(
decoder_padding_mask=decoder_padding_mask,
layer_state=layer_state,
causal_mask=decoder_causal_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
output_attentions=output_attentions,
)

@@ -761,6 +829,7 @@ def forward(
key_padding_mask: Optional[Tensor] = None,
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
attn_mask: Optional[Tensor] = None,
layer_head_mask: Optional[Tensor] = None,
output_attentions=False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time(SeqLen) x Batch x Channel"""
@@ -830,6 +899,13 @@ def forward(

attn_weights = F.softmax(attn_weights, dim=-1)

if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

if output_attentions:
# make sure that attn_weights are included in graph
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
Expand Down Expand Up @@ -923,6 +999,8 @@ def forward(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Tuple] = None,
past_key_values=None,
use_cache=None,
@@ -958,6 +1036,7 @@ def forward(
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -977,6 +1056,8 @@ def forward(
attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
head_mask=decoder_head_mask,
encoder_head_mask=head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
@@ -1052,6 +1133,8 @@ def forward(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
labels=None,
@@ -1080,6 +1163,8 @@ def forward(
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
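To make the new arguments concrete, here is a minimal usage sketch. It is not part of the diff; it assumes the `facebook/wmt19-en-de` checkpoint and a transformers build that already contains this change. As the docstrings above state, the masks have shape `(num_layers, num_heads)`, where 1 keeps a head and 0 nullifies it.

```python
# Illustrative only: exercising the new head_mask / decoder_head_mask arguments.
# Assumes the facebook/wmt19-en-de checkpoint and that this change is installed.
import torch
from transformers import FSMTForConditionalGeneration, FSMTTokenizer

tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-en-de")
model = FSMTForConditionalGeneration.from_pretrained("facebook/wmt19-en-de")

inputs = tokenizer("Machine learning is great", return_tensors="pt")

# 1 keeps a head, 0 nullifies it; shape is (num_layers, num_heads).
head_mask = torch.ones(model.config.encoder_layers, model.config.encoder_attention_heads)
head_mask[0, 0] = 0.0  # silence the first head of the first encoder layer
decoder_head_mask = torch.ones(model.config.decoder_layers, model.config.decoder_attention_heads)

outputs = model(
    **inputs,
    decoder_input_ids=inputs["input_ids"],
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    output_attentions=True,
)

# The mask is applied to the attention probabilities, so the nullified head's
# returned attention weights are all zeros.
print(outputs.encoder_attentions[0][0, 0].abs().sum())  # tensor(0.)
```

Because `layer_head_mask` multiplies the attention probabilities before they are returned, `output_attentions=True` is a convenient way to verify which heads were nullified.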
9 changes: 8 additions & 1 deletion tests/test_modeling_fsmt.py
@@ -111,12 +111,20 @@ def prepare_fsmt_inputs_dict(
config,
input_ids,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}


@@ -126,7 +134,6 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_missing_keys = False

def setUp(self):
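For reference, the core operation this change adds to `Attention.forward` is a per-head broadcast over the attention probabilities. The sketch below reproduces it with toy shapes in plain PyTorch; the names mirror the modeling code, but nothing here is FSMT-specific.

```python
# Standalone illustration of the layer_head_mask broadcast used in the diff above.
import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = torch.softmax(torch.randn(bsz * num_heads, tgt_len, src_len), dim=-1)

layer_head_mask = torch.ones(num_heads)
layer_head_mask[1] = 0.0  # nullify head 1 in this layer

# (1, num_heads, 1, 1) * (bsz, num_heads, tgt_len, src_len) zeroes whole heads at once
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * num_heads, tgt_len, src_len)

# Head 1 now contributes nothing for any batch element.
print(attn_weights.view(bsz, num_heads, tgt_len, src_len)[:, 1].abs().sum())  # tensor(0.)
```

The all-ones masks constructed in `prepare_fsmt_inputs_dict` above are the no-op case of this same operation, which is why the default test inputs leave the model's behaviour unchanged.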