Description
🚀 Have RAG return generator cross-attentions when output_attentions=True
This feature request is for the RAG code to be modified so that if output_attentions=True, it returns the generator's cross-attentions in addition to the attentions it already returns.
Motivation
I'm interested in extracting the generator's attentions from a RAG generator model. Currently, transformers allows you to extract the generator's encoder attentions and decoder attentions, but not its cross-attentions. For example, inside modeling_rag.py, return objects such as RetrievAugLMMarginOutput have fields for these other attentions, but not for the cross-attentions.
Because both T5 and BART can output cross-attentions, I think they could simply be propagated up through the RAG code. Is there a reason this isn't already the case? Or could I open a PR to include the cross-attentions along with the other attentions in the model output?
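To illustrate what "propagating up" could look like, here is a minimal, library-free sketch. It is not the actual modeling_rag.py code: GeneratorOutput, RagLikeOutput, run_generator and wrap_generator_output are hypothetical stand-ins, the attention values are placeholder strings rather than tensors, and generator_cross_attentions is only an assumed name for the proposed field, chosen to match the style of the existing encoder and decoder attention fields.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Stand-in for a tuple of per-layer attention tensors (strings here, not tensors).
AttentionTuple = Tuple[str, ...]


@dataclass
class GeneratorOutput:
    """Hypothetical stand-in for a seq2seq generator output (e.g. BART or T5)."""
    encoder_attentions: Optional[AttentionTuple] = None
    decoder_attentions: Optional[AttentionTuple] = None
    cross_attentions: Optional[AttentionTuple] = None


@dataclass
class RagLikeOutput:
    """Hypothetical stand-in for a RAG output object."""
    generator_enc_attentions: Optional[AttentionTuple] = None
    generator_dec_attentions: Optional[AttentionTuple] = None
    # Assumed name for the field this feature request proposes.
    generator_cross_attentions: Optional[AttentionTuple] = None


def run_generator(output_attentions: bool, num_layers: int = 2) -> GeneratorOutput:
    """Fake generator: fills the attention fields only when they are requested."""
    if not output_attentions:
        return GeneratorOutput()
    return GeneratorOutput(
        encoder_attentions=tuple(f"enc_layer_{i}" for i in range(num_layers)),
        decoder_attentions=tuple(f"dec_layer_{i}" for i in range(num_layers)),
        cross_attentions=tuple(f"cross_layer_{i}" for i in range(num_layers)),
    )


def wrap_generator_output(gen: GeneratorOutput) -> RagLikeOutput:
    """Copy every attention field from the generator output, including cross-attentions."""
    return RagLikeOutput(
        generator_enc_attentions=gen.encoder_attentions,
        generator_dec_attentions=gen.decoder_attentions,
        generator_cross_attentions=gen.cross_attentions,
    )


rag_out = wrap_generator_output(run_generator(output_attentions=True))
print(rag_out.generator_cross_attentions)
```

In this sketch the only change compared with forwarding just the encoder and decoder attentions is one extra field on the output object and one extra assignment in the wrapper; whether the real change is that small depends on how the RAG code reshapes and marginalizes generator outputs.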
Your contribution
On my own fork of transformers, I've already added this feature and would happily submit a PR!