-
Notifications
You must be signed in to change notification settings - Fork 31.9k
[All Seq2Seq model + CLM models that can be used with EncoderDecoder] Add cross-attention weights to outputs #8071
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
patrickvonplaten
merged 8 commits into
huggingface:master
from
ysgit:output-crossattention
Nov 6, 2020
Merged
Changes from 1 commit
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
925092d
Output cross-attention with decoder attention output
4c182fa
Update src/transformers/modeling_bert.py
patrickvonplaten bd7eb88
Merge remote-tracking branch 'main/master' into output-crossattention
patrickvonplaten 8f8575f
add cross-attention for t5 and bart as well
patrickvonplaten d9a5e35
fix tests
patrickvonplaten 518acf6
correct typo in docs
patrickvonplaten 9225cab
add sylvains and sams comments
patrickvonplaten 0205b70
correct typo
patrickvonplaten File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,7 +33,11 @@ | |
| add_start_docstrings_to_model_forward, | ||
| replace_return_docstrings, | ||
| ) | ||
| from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast | ||
| from .modeling_outputs import ( | ||
| BaseModelOutputWithPastAndCrossAttentions, | ||
| CausalLMOutputWithPastAndCrossAttentions, | ||
| SequenceClassifierOutputWithPast, | ||
| ) | ||
| from .modeling_utils import ( | ||
| Conv1D, | ||
| PreTrainedModel, | ||
|
|
@@ -311,14 +315,14 @@ def forward( | |
| attn_output = cross_attn_outputs[0] | ||
| # residual connection | ||
| hidden_states = hidden_states + attn_output | ||
| outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights | ||
| outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. great catch! |
||
|
|
||
| feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states)) | ||
| # residual connection | ||
| hidden_states = hidden_states + feed_forward_hidden_states | ||
|
|
||
| outputs = [hidden_states] + outputs | ||
| return outputs # hidden_states, present, (cross_attentions, attentions) | ||
| return outputs # hidden_states, present, (attentions, cross_attentions) | ||
|
|
||
|
|
||
| class GPT2PreTrainedModel(PreTrainedModel): | ||
|
|
@@ -506,7 +510,7 @@ def _prune_heads(self, heads_to_prune): | |
| @add_code_sample_docstrings( | ||
| tokenizer_class=_TOKENIZER_FOR_DOC, | ||
| checkpoint="gpt2", | ||
| output_type=BaseModelOutputWithPast, | ||
| output_type=BaseModelOutputWithPastAndCrossAttentions, | ||
| config_class=_CONFIG_FOR_DOC, | ||
| ) | ||
| def forward( | ||
|
|
@@ -618,7 +622,8 @@ def forward( | |
| output_shape = input_shape + (hidden_states.size(-1),) | ||
|
|
||
| presents = () if use_cache else None | ||
| all_attentions = () if output_attentions else None | ||
| all_self_attentions = () if output_attentions else None | ||
| all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None | ||
| all_hidden_states = () if output_hidden_states else None | ||
| for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): | ||
| if output_hidden_states: | ||
|
|
@@ -659,7 +664,9 @@ def custom_forward(*inputs): | |
| presents = presents + (present,) | ||
|
|
||
| if output_attentions: | ||
| all_attentions = all_attentions + (outputs[2],) | ||
| all_self_attentions = all_self_attentions + (outputs[2],) | ||
| if self.config.add_cross_attention: | ||
| all_cross_attentions = all_cross_attentions + (outputs[3],) | ||
|
|
||
| hidden_states = self.ln_f(hidden_states) | ||
|
|
||
|
|
@@ -669,13 +676,14 @@ def custom_forward(*inputs): | |
| all_hidden_states = all_hidden_states + (hidden_states,) | ||
|
|
||
| if not return_dict: | ||
| return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) | ||
| return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) | ||
|
|
||
| return BaseModelOutputWithPast( | ||
| return BaseModelOutputWithPastAndCrossAttentions( | ||
| last_hidden_state=hidden_states, | ||
| past_key_values=presents, | ||
| hidden_states=all_hidden_states, | ||
| attentions=all_attentions, | ||
| attentions=all_self_attentions, | ||
| cross_attentions=all_cross_attentions, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -727,7 +735,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): | |
| @add_code_sample_docstrings( | ||
| tokenizer_class=_TOKENIZER_FOR_DOC, | ||
| checkpoint="gpt2", | ||
| output_type=CausalLMOutputWithPast, | ||
| output_type=CausalLMOutputWithPastAndCrossAttentions, | ||
| config_class=_CONFIG_FOR_DOC, | ||
| ) | ||
| def forward( | ||
|
|
@@ -795,12 +803,13 @@ def forward( | |
| output = (lm_logits,) + transformer_outputs[1:] | ||
| return ((loss,) + output) if loss is not None else output | ||
|
|
||
| return CausalLMOutputWithPast( | ||
| return CausalLMOutputWithPastAndCrossAttentions( | ||
| loss=loss, | ||
| logits=lm_logits, | ||
| past_key_values=transformer_outputs.past_key_values, | ||
| hidden_states=transformer_outputs.hidden_states, | ||
| attentions=transformer_outputs.attentions, | ||
| cross_attentions=transformer_outputs.cross_attentions, | ||
| ) | ||
|
|
||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.