Contributor

@dblakely dblakely commented Jan 25, 2021

What does this PR do?

This PR makes RAG output the generator model's decoder cross-attentions when output_attentions=True.

Motivation and context: before this PR, RAG's output objects had attributes for the generator's encoder self-attentions and decoder self-attentions, but none for the encoder-decoder cross-attentions. This PR exposes the cross-attentions and also fixes a small bug where output_attentions wasn't being passed into the generator.
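
For readers who want the shape of the change without opening the diff, here is a minimal toy sketch. It is not the actual transformers code: `ToyGenerator`, `ToyRag`, and both output classes are hypothetical stand-ins, the field names are illustrative only, and each "attention" is just a string label rather than a tensor. It shows the two points described above: the wrapper forwards output_attentions to the generator, and the wrapper's output carries the cross-attentions alongside the encoder and decoder self-attentions.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple

Attentions = Optional[Tuple[str, ...]]


@dataclass
class ToyGeneratorOutput:
    """Hypothetical stand-in for a seq2seq generator's output."""
    encoder_attentions: Attentions = None
    decoder_attentions: Attentions = None
    cross_attentions: Attentions = None


class ToyGenerator:
    """Hypothetical generator: returns attention labels only when asked."""

    def __call__(self, tokens: Sequence[str], output_attentions: bool = False) -> ToyGeneratorOutput:
        if not output_attentions:
            return ToyGeneratorOutput()
        return ToyGeneratorOutput(
            encoder_attentions=("enc_self",),
            decoder_attentions=("dec_self",),
            cross_attentions=("enc_dec_cross",),
        )


@dataclass
class ToyRagOutput:
    """Hypothetical wrapper output; field names are illustrative."""
    generator_enc_attentions: Attentions = None
    generator_dec_attentions: Attentions = None
    generator_cross_attentions: Attentions = None


class ToyRag:
    """Hypothetical wrapper around a generator, mimicking the described fix."""

    def __init__(self, generator: ToyGenerator) -> None:
        self.generator = generator

    def __call__(self, tokens: Sequence[str], output_attentions: bool = False) -> ToyRagOutput:
        # Forward the flag; dropping it here would leave every field below as None.
        gen_out = self.generator(tokens, output_attentions=output_attentions)
        return ToyRagOutput(
            generator_enc_attentions=gen_out.encoder_attentions,
            generator_dec_attentions=gen_out.decoder_attentions,
            # The extra field: cross-attentions are surfaced too.
            generator_cross_attentions=gen_out.cross_attentions,
        )


if __name__ == "__main__":
    rag = ToyRag(ToyGenerator())
    print(rag(["a", "question"], output_attentions=True))
```

If the wrapper called the generator without the flag, all three attention fields would stay None regardless of what the caller requested, which is the kind of symptom the bug fix addresses.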

Fixes #9468

Who can review?

@patrickvonplaten, @lhoestq

@dblakely dblakely changed the title Allow RAG to output decoder cross-attentions [WIP] Allow RAG to output decoder cross-attentions Jan 25, 2021
@dblakely dblakely changed the title [WIP] Allow RAG to output decoder cross-attentions Allow RAG to output decoder cross-attentions Jan 25, 2021
Member

@lhoestq lhoestq left a comment

Thanks! Looks good to me :)

The suggestions below should help you fix the CI.
I think the failure just comes from a line length issue.

Contributor Author

@lhoestq Thanks for the suggestions! All the CI checks pass now.

Member

@lhoestq lhoestq left a comment

Nice!
Also pinging @patrickvonplaten just to make sure.

Contributor

@patrickvonplaten patrickvonplaten left a comment

Super! Thanks for your contribution.

@patrickvonplaten patrickvonplaten merged commit 8edc98b into huggingface:master Jan 26, 2021

Development

Successfully merging this pull request may close these issues.

Have RAG return generator cross-attentions when output_attentions=True

3 participants