18 changes: 18 additions & 0 deletions src/transformers/models/rag/modeling_rag.py
@@ -102,6 +102,12 @@ class RetrievAugLMMarginOutput(ModelOutput):

Attention weights of the generator decoder, after the attention softmax, used to compute the weighted
average in the self-attention heads.
generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.

Cross-attention weights of the generator decoder, after the attention softmax, used to compute
the weighted average in the cross-attention heads.
"""

loss: Optional[torch.FloatTensor] = None
@@ -120,6 +126,7 @@ class RetrievAugLMMarginOutput(ModelOutput):
generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
@@ -186,6 +193,12 @@ class RetrievAugLMOutput(ModelOutput):

Attention weights of the generator decoder, after the attention softmax, used to compute the weighted
average in the self-attention heads.
generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.

Cross-attention weights of the generator decoder, after the attention softmax, used to compute
the weighted average in the cross-attention heads.
"""

logits: torch.FloatTensor = None
@@ -203,6 +216,7 @@ class RetrievAugLMOutput(ModelOutput):
generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


class RagPreTrainedModel(PreTrainedModel):
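
Both output classes now document the same layout for the new field: a tuple with one cross-attention tensor per generator decoder layer, defaulting to None so existing callers are unaffected. A quick illustrative sketch of that structure (all dimension values are hypothetical stand-ins for the documented shape):

import torch

# Hypothetical dimensions standing in for the documented
# (batch_size, num_heads, sequence_length, sequence_length) shape.
num_layers, batch_size, num_heads, seq_len = 12, 2, 16, 10
generator_cross_attentions = tuple(
    torch.rand(batch_size, num_heads, seq_len, seq_len) for _ in range(num_layers)
)
assert len(generator_cross_attentions) == num_layers  # one entry per decoder layer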
@@ -619,6 +633,7 @@ def forward(
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
return_dict=True,
)

@@ -655,6 +670,7 @@ def forward(
generator_enc_attentions=gen_outputs.encoder_attentions,
generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
generator_dec_attentions=gen_outputs.decoder_attentions,
generator_cross_attentions=gen_outputs.cross_attentions,
)


@@ -803,6 +819,7 @@ def forward(
generator_enc_attentions=outputs.generator_enc_attentions,
generator_dec_hidden_states=outputs.generator_dec_hidden_states,
generator_dec_attentions=outputs.generator_dec_attentions,
generator_cross_attentions=outputs.generator_cross_attentions,
)

@property
@@ -1264,6 +1281,7 @@ def forward(
generator_enc_attentions=outputs.generator_enc_attentions,
generator_dec_hidden_states=outputs.generator_dec_hidden_states,
generator_dec_attentions=outputs.generator_dec_attentions,
generator_cross_attentions=outputs.generator_cross_attentions,
)

@torch.no_grad()
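
Taken together, the change threads output_attentions through to the underlying generator and surfaces the generator's cross_attentions on both RAG output classes, so the tensors follow the generator's (e.g. BART's) cross-attention convention. A minimal usage sketch, assuming the facebook/rag-sequence-nq checkpoint with the dummy retrieval index used in the library's RAG examples (question and answer strings are illustrative):

import torch
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
)
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq", retriever=retriever
)

input_dict = tokenizer.prepare_seq2seq_batch(
    "who holds the record in 100m freestyle", "michael phelps", return_tensors="pt"
)

with torch.no_grad():
    outputs = model(
        input_dict["input_ids"], labels=input_dict["labels"], output_attentions=True
    )

# One tensor per generator decoder layer; inside RAG the batch dimension is
# expanded by n_docs, so expect a leading dim of batch_size * n_docs.
print(len(outputs.generator_cross_attentions))
print(outputs.generator_cross_attentions[0].shape)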