diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 2b733cd1b6bd..a57598d84397 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -102,6 +102,12 @@ class RetrievAugLMMarginOutput(ModelOutput):
 
             Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
             average in the self-attention heads.
+        generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
     """
 
     loss: Optional[torch.FloatTensor] = None
@@ -120,6 +126,7 @@ class RetrievAugLMMarginOutput(ModelOutput):
     generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
     generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
 @dataclass
@@ -186,6 +193,12 @@ class RetrievAugLMOutput(ModelOutput):
 
             Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
             average in the self-attention heads.
+        generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
     """
 
     logits: torch.FloatTensor = None
@@ -203,6 +216,7 @@ class RetrievAugLMOutput(ModelOutput):
     generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
     generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
 class RagPreTrainedModel(PreTrainedModel):
@@ -619,6 +633,7 @@ def forward(
             decoder_attention_mask=decoder_attention_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
+            output_attentions=output_attentions,
             return_dict=True,
         )
@@ -655,6 +670,7 @@ def forward(
             generator_enc_attentions=gen_outputs.encoder_attentions,
             generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
             generator_dec_attentions=gen_outputs.decoder_attentions,
+            generator_cross_attentions=gen_outputs.cross_attentions,
         )
@@ -803,6 +819,7 @@ def forward(
             generator_enc_attentions=outputs.generator_enc_attentions,
             generator_dec_hidden_states=outputs.generator_dec_hidden_states,
             generator_dec_attentions=outputs.generator_dec_attentions,
+            generator_cross_attentions=outputs.generator_cross_attentions,
         )
 
     @property
@@ -1264,6 +1281,7 @@ def forward(
             generator_enc_attentions=outputs.generator_enc_attentions,
             generator_dec_hidden_states=outputs.generator_dec_hidden_states,
             generator_dec_attentions=outputs.generator_dec_attentions,
+            generator_cross_attentions=outputs.generator_cross_attentions,
         )
 
     @torch.no_grad()
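
A minimal sketch of how the new `generator_cross_attentions` field could be consumed once this patch is applied. The checkpoint name, question/target strings, and the dummy retrieval index below are illustrative assumptions for demonstration, not part of the change itself.

```python
# Sketch: reading the generator's cross-attentions from a RAG forward pass.
# Assumptions: the facebook/rag-token-nq checkpoint and its dummy index are
# available (retrieval requires the `datasets` and `faiss` packages).
import torch
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

inputs = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")
# Tokenize a target with the generator's tokenizer so the forward pass has decoder inputs.
labels = tokenizer.generator("michael phelps", return_tensors="pt")["input_ids"]

with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"], labels=labels, output_attentions=True)

# One tensor per generator decoder layer; note that with RAG the effective batch
# dimension is batch_size * n_docs, since the question is paired with each retrieved document.
cross_attentions = outputs.generator_cross_attentions
print(len(cross_attentions), cross_attentions[0].shape)
```

Without the `output_attentions=output_attentions` forwarding added in this patch, the generator would never be asked for its attention tensors and `gen_outputs.cross_attentions` would be `None` even when the caller requests them.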