
Have RAG return generator cross-attentions when output_attentions=True #9468

@dblakely


This feature request asks for the RAG code to be modified so that, when output_attentions=True, it also returns the generator's cross-attentions in addition to the attentions it already returns.

Motivation

I'm interested in extracting the generator's attentions from a RAG model. Currently, transformers lets you extract the generator's encoder attentions and decoder attentions, but not its cross-attentions. For example, inside modeling_rag.py, return objects such as RetrievAugLMMarginOutput have fields for these other attentions but not for the cross-attentions.

Because both T5 and BART can already output cross-attentions, I think they could simply be propagated up through the RAG code. Is there a reason this isn't already the case? Or could I open a PR to include the cross-attentions along with the other attentions in the model output?
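To illustrate the kind of propagation meant here, below is a minimal sketch using plain stdlib dataclasses. These are hypothetical stand-ins, not the transformers classes: GeneratorOutput loosely mimics a seq2seq generator output (encoder, decoder, and cross-attentions), RagOutput mimics a RAG-level output, and the field name generator_cross_attentions and the helper wrap_generator_output are illustrative assumptions that follow the naming of the existing generator attention fields.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical stand-in type: one entry per layer (real code would hold tensors).
Attn = Tuple[object, ...]


@dataclass
class GeneratorOutput:
    """Hypothetical stand-in for a seq2seq generator's output (e.g. BART/T5)."""
    encoder_attentions: Optional[Attn] = None
    decoder_attentions: Optional[Attn] = None
    cross_attentions: Optional[Attn] = None


@dataclass
class RagOutput:
    """Hypothetical stand-in for a RAG-level output object."""
    generator_enc_attentions: Optional[Attn] = None
    generator_dec_attentions: Optional[Attn] = None
    # Proposed addition (hypothetical name): the generator's cross-attentions.
    generator_cross_attentions: Optional[Attn] = None


def wrap_generator_output(gen_out: GeneratorOutput) -> RagOutput:
    """Copy the generator's attentions, including cross-attentions, into the RAG output."""
    return RagOutput(
        generator_enc_attentions=gen_out.encoder_attentions,
        generator_dec_attentions=gen_out.decoder_attentions,
        generator_cross_attentions=gen_out.cross_attentions,
    )


gen = GeneratorOutput(
    encoder_attentions=("enc",),
    decoder_attentions=("dec",),
    cross_attentions=("cross",),
)
rag = wrap_generator_output(gen)
print(rag.generator_cross_attentions)  # ('cross',)
```

The point of the sketch is that the change is additive: the generator already produces the cross-attentions, so the RAG wrapper only needs one more field and one more assignment, and the field defaults to None when attentions are not requested.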

Your contribution

On my own fork of transformers, I've already added this feature and would happily submit a PR!
