Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions src/transformers/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
CausalLMOutput,
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
Expand Down Expand Up @@ -449,7 +449,8 @@ def forward(
return_dict=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
Expand Down Expand Up @@ -483,15 +484,24 @@ def custom_forward(*inputs):
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)


Expand Down Expand Up @@ -752,7 +762,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=BaseModelOutputWithPooling,
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
Expand Down Expand Up @@ -843,11 +853,12 @@ def forward(
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]

return BaseModelOutputWithPooling(
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)


Expand Down Expand Up @@ -984,7 +995,7 @@ def get_output_embeddings(self):
return self.cls.predictions.decoder

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
Expand Down Expand Up @@ -1063,11 +1074,12 @@ def forward(
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output

return CausalLMOutput(
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.attentions,
)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
Expand Down
12 changes: 7 additions & 5 deletions src/transformers/modeling_bert_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
replace_return_docstrings,
)
from .modeling_bert import BertEncoder
from .modeling_outputs import BaseModelOutput, CausalLMOutput
from .modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions
from .modeling_utils import PreTrainedModel
from .utils import logging

Expand Down Expand Up @@ -297,7 +297,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder",
output_type=BaseModelOutput,
output_type=BaseModelOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
Expand Down Expand Up @@ -381,10 +381,11 @@ def forward(
if not return_dict:
return (sequence_output,) + encoder_outputs[1:]

return BaseModelOutput(
return BaseModelOutputWithCrossAttentions(
last_hidden_state=sequence_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)


Expand Down Expand Up @@ -422,7 +423,7 @@ def get_output_embeddings(self):
return self.lm_head.decoder

@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
Expand Down Expand Up @@ -499,11 +500,12 @@ def forward(
output = (prediction_scores,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output

return CausalLMOutput(
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
Expand Down
24 changes: 17 additions & 7 deletions src/transformers/modeling_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
Expand Down Expand Up @@ -445,7 +445,8 @@ def forward(
return_dict=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
Expand Down Expand Up @@ -479,15 +480,24 @@ def custom_forward(*inputs):
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)


Expand Down Expand Up @@ -697,7 +707,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/electra-small-discriminator",
output_type=BaseModelOutput,
output_type=BaseModelOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
Expand Down
1 change: 1 addition & 0 deletions src/transformers/modeling_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ def forward(
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
Expand Down
31 changes: 20 additions & 11 deletions src/transformers/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from .modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithPastAndCrossAttentions,
SequenceClassifierOutputWithPast,
)
from .modeling_utils import (
Conv1D,
PreTrainedModel,
Expand Down Expand Up @@ -311,14 +315,14 @@ def forward(
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = hidden_states + attn_output
outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great catch!


feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
# residual connection
hidden_states = hidden_states + feed_forward_hidden_states

outputs = [hidden_states] + outputs
return outputs # hidden_states, present, (cross_attentions, attentions)
return outputs # hidden_states, present, (attentions, cross_attentions)


class GPT2PreTrainedModel(PreTrainedModel):
Expand Down Expand Up @@ -506,7 +510,7 @@ def _prune_heads(self, heads_to_prune):
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="gpt2",
output_type=BaseModelOutputWithPast,
output_type=BaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
Expand Down Expand Up @@ -618,7 +622,8 @@ def forward(
output_shape = input_shape + (hidden_states.size(-1),)

presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
Expand Down Expand Up @@ -659,7 +664,9 @@ def custom_forward(*inputs):
presents = presents + (present,)

if output_attentions:
all_attentions = all_attentions + (outputs[2],)
all_self_attentions = all_self_attentions + (outputs[2],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3],)

hidden_states = self.ln_f(hidden_states)

Expand All @@ -669,13 +676,14 @@ def custom_forward(*inputs):
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

return BaseModelOutputWithPast(
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)


Expand Down Expand Up @@ -727,7 +735,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="gpt2",
output_type=CausalLMOutputWithPast,
output_type=CausalLMOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
Expand Down Expand Up @@ -795,12 +803,13 @@ def forward(
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return CausalLMOutputWithPast(
return CausalLMOutputWithPastAndCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)


Expand Down
32 changes: 24 additions & 8 deletions src/transformers/modeling_layoutlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
from .activations import ACT2FN
from .configuration_layoutlm import LayoutLMConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, TokenClassifierOutput
from .modeling_outputs import (
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput,
TokenClassifierOutput,
)
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
Expand Down Expand Up @@ -374,7 +379,8 @@ def forward(
return_dict=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
Expand Down Expand Up @@ -408,15 +414,24 @@ def custom_forward(*inputs):
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)


Expand Down Expand Up @@ -611,7 +626,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="layoutlm-base-uncased",
output_type=BaseModelOutputWithPooling,
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
Expand Down Expand Up @@ -716,11 +731,12 @@ def forward(
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]

return BaseModelOutputWithPooling(
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)


Expand Down
Loading