13 changes: 9 additions & 4 deletions src/transformers/generation_utils.py
@@ -372,11 +372,16 @@ def generate(

if self.config.is_encoder_decoder:
if decoder_start_token_id is None:
decoder_start_token_id = bos_token_id
# see if BOS token can be used for decoder_start_token_id
if bos_token_id is not None:
decoder_start_token_id = bos_token_id
elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
decoder_start_token_id = self.config.decoder.bos_token_id
else:
raise ValueError(
"decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
)

assert (
decoder_start_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
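
The fallback order above can be summarized in a small standalone helper; this is an illustrative sketch that mirrors the diff, not a function that exists in the library:

```python
def resolve_decoder_start_token_id(decoder_start_token_id, bos_token_id, config):
    # Mirrors the resolution order in `generate` for encoder-decoder models:
    # explicit argument > bos_token_id > decoder sub-config's bos_token_id.
    if decoder_start_token_id is not None:
        return decoder_start_token_id
    if bos_token_id is not None:
        return bos_token_id
    if hasattr(config, "decoder") and hasattr(config.decoder, "bos_token_id"):
        return config.decoder.bos_token_id
    raise ValueError(
        "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
    )
```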

23 changes: 17 additions & 6 deletions src/transformers/modeling_encoder_decoder.py
@@ -287,6 +287,8 @@ def forward(
**kwargs_decoder,
)

# TODO(PVP): currently it is not possible to use `past`
# with the encoder/decoder framework -> should be implemented
return decoder_outputs + encoder_outputs

def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
@@ -299,15 +301,24 @@ def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
encoder_outputs = (past,)

decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)

return {
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_inputs["attention_mask"],
"decoder_attention_mask": decoder_attention_mask,
"decoder_input_ids": decoder_inputs["input_ids"],
"encoder_outputs": encoder_outputs,
}

# Ideally all models should have a `use_cache` flag;
# leave the following ifs in place until all of them have it implemented
if "use_cache" in decoder_inputs:
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]

if "past_key_values" in decoder_inputs:
input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"]

return input_dict

def _reorder_cache(self, past, beam_idx):
# as a default encoder-decoder models do not re-order the past.
# TODO(PVP): might have to be updated, e.g. if GPT2 is to be used as a decoder
return past
# apply decoder cache reordering here
return self.decoder._reorder_cache(past, beam_idx)
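
Delegating to `self.decoder._reorder_cache` means the decoder decides how its own cache is rearranged during beam search. As background, a decoder-side `_reorder_cache` typically selects the cache entries of the surviving beams; the snippet below is an illustrative sketch of that idea, not the actual GPT-2 implementation:

```python
import torch

def reorder_cache_sketch(past, beam_idx):
    # `past` is a tuple with one cached key/value tensor per layer; pick the
    # entries that correspond to the beams kept by the beam-search step so the
    # cache stays aligned with the reordered input_ids.
    return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)

# Toy usage: two layers, cache shaped (2, batch*beams, heads, seq, head_dim)
past = tuple(torch.rand(2, 4, 12, 3, 64) for _ in range(2))
reordered = reorder_cache_sketch(past, torch.tensor([1, 0, 3, 2]))
```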
130 changes: 104 additions & 26 deletions src/transformers/modeling_gpt2.py
@@ -118,7 +118,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):


class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False):
def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
super().__init__()

n_state = nx # in Attention: n_state=768 (nx=n_embd)
@@ -131,8 +131,12 @@ def __init__(self, nx, n_ctx, config, scale=False):
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale

self.c_attn = Conv1D(n_state * 3, nx)
self.is_cross_attention = is_cross_attention
if self.is_cross_attention:
self.c_attn = Conv1D(2 * n_state, nx)
self.q_attn = Conv1D(n_state, nx)
else:
self.c_attn = Conv1D(3 * n_state, nx)
self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
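
A quick way to see what the new branch changes is to compare the projection shapes: with `is_cross_attention=True`, the fused `c_attn` only produces keys and values (2 × n_state) and queries come from the separate `q_attn`, whereas the self-attention path keeps the fused 3 × n_state projection. The check below is a standalone sketch; `Conv1D` is imported from `transformers.modeling_utils` here, which may differ across library versions:

```python
import torch
from transformers.modeling_utils import Conv1D  # Conv1D(out_features, in_features)

n_state = nx = 768
c_attn_cross = Conv1D(2 * n_state, nx)  # key/value projection used for cross-attention
q_attn = Conv1D(n_state, nx)            # query projection from the decoder hidden states
c_attn_self = Conv1D(3 * n_state, nx)   # fused query/key/value projection for self-attention

x = torch.rand(1, 5, nx)
print(c_attn_cross(x).shape)  # torch.Size([1, 5, 1536])
print(q_attn(x).shape)        # torch.Size([1, 5, 768])
print(c_attn_self(x).shape)   # torch.Size([1, 5, 2304])
```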
@@ -160,8 +164,11 @@ def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
if self.scale:
w = w / (float(v.size(-1)) ** 0.5)
nd, ns = w.size(-2), w.size(-1)
mask = self.bias[:, :, ns - nd : ns, :ns]
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))

if not self.is_cross_attention:
# only the "normal" self-attention layer implements the causal mask
mask = self.bias[:, :, ns - nd : ns, :ns]
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))

if attention_mask is not None:
# Apply the attention mask
@@ -193,10 +200,26 @@ def split_heads(self, x, k=False):
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)

def forward(
self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False,
):
x = self.c_attn(x)
query, key, value = x.split(self.split_size, dim=2)
if encoder_hidden_states is not None:
assert hasattr(
self, "q_attn"
), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self.split_heads(query)
key = self.split_heads(key, k=True)
value = self.split_heads(value)
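
This branch is the core of the change: in cross-attention the queries come from the decoder's `hidden_states`, the keys and values come from `encoder_hidden_states`, and the encoder's padding mask replaces the decoder mask. The single-head sketch below illustrates that data flow with plain linear layers and hypothetical tensor sizes; it is not the multi-head code above:

```python
import torch

batch, src_len, tgt_len, n_embd = 2, 7, 5, 768
decoder_hidden = torch.rand(batch, tgt_len, n_embd)  # what the GPT-2 block computes
encoder_hidden = torch.rand(batch, src_len, n_embd)  # what the encoder (e.g. BERT) produced

q_proj = torch.nn.Linear(n_embd, n_embd)
kv_proj = torch.nn.Linear(n_embd, 2 * n_embd)

query = q_proj(decoder_hidden)                             # (batch, tgt_len, n_embd)
key, value = kv_proj(encoder_hidden).split(n_embd, dim=2)  # each (batch, src_len, n_embd)

scores = torch.matmul(query, key.transpose(-1, -2)) / n_embd ** 0.5  # (batch, tgt_len, src_len)
context = torch.softmax(scores, dim=-1) @ value                      # (batch, tgt_len, n_embd)
```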
@@ -239,32 +262,64 @@ def forward(self, x):
class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False):
super().__init__()
nx = config.n_embd
inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
self.attn = Attention(nx, n_ctx, config, scale)
self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
hidden_size = config.n_embd
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = Attention(hidden_size, n_ctx, config, scale)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if config.add_cross_attention:
self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = MLP(inner_dim, config)

def forward(
self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False,
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False,
):
output_attn = self.attn(
self.ln_1(x),
attn_outputs = self.attn(
self.ln_1(hidden_states),
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
a = output_attn[0] # output_attn: a, present, (attentions)
attn_output = attn_outputs[0] # attn_outputs: attn_output, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
hidden_states = attn_output + hidden_states

if encoder_hidden_states is not None:
# add one self-attention block for cross-attention
assert hasattr(
self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
cross_attn_outputs = self.crossattention(
self.ln_cross_attn(hidden_states),
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = hidden_states + attn_output
outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights

x = x + a
m = self.mlp(self.ln_2(x))
x = x + m
feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
# residual connection
hidden_states = hidden_states + feed_forward_hidden_states

outputs = [x] + output_attn[1:]
return outputs # x, present, (attentions)
outputs = [hidden_states] + outputs
return outputs # hidden_states, present, (cross_attentions, attentions)
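
After this change the residual stream optionally passes through a third sub-layer between self-attention and the feed-forward MLP. The helper below is a schematic of that ordering (pre-LayerNorm residual layout), written against generic callables rather than the real modules:

```python
def block_forward_sketch(hidden_states, ln_1, attn, ln_2, mlp,
                         ln_cross=None, cross_attn=None, encoder_hidden_states=None):
    # self-attention sub-layer + residual
    hidden_states = hidden_states + attn(ln_1(hidden_states))
    # optional cross-attention sub-layer + residual (only if encoder states are given)
    if encoder_hidden_states is not None:
        hidden_states = hidden_states + cross_attn(ln_cross(hidden_states), encoder_hidden_states)
    # feed-forward sub-layer + residual
    return hidden_states + mlp(ln_2(hidden_states))
```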


class GPT2PreTrainedModel(PreTrainedModel):
@@ -449,6 +504,8 @@ def forward(
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
@@ -506,7 +563,7 @@ def forward(
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask[:, None, None, :]

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -516,6 +573,17 @@ def forward(
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0

# If a 2D or 3D attention mask is provided for the cross-attention,
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
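
Both the decoder mask and the encoder mask end up in additive form: positions to attend contribute 0.0 and padded positions a large negative value, so the mask can simply be added to the raw attention scores before the softmax. A standalone illustration of that arithmetic (not the library's `invert_attention_mask` itself):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])     # 1 = real token, 0 = padding
extended = attention_mask[:, None, None, :].float()  # (batch, 1, 1, seq_len) for broadcasting
additive = (1.0 - extended) * -10000.0               # 0.0 where attended, -10000.0 where masked
```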

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
@@ -546,6 +614,8 @@ def forward(
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
@@ -593,17 +663,21 @@ def __init__(self, config):
def get_output_embeddings(self):
return self.lm_head

def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)

return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs["use_cache"] if "use_cache" in kwargs else None,
}
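
The point of the `past` branch is that once cached key/values exist, only the most recent token has to be fed through the model again; the slice keeps a `(batch, 1)` shape so downstream code still sees a 2D tensor. A tiny illustration with hypothetical token ids:

```python
import torch

input_ids = torch.tensor([[50256, 464, 3290, 318]])  # full sequence generated so far
last_token = input_ids[:, -1].unsqueeze(-1)          # shape (1, 1) -> tensor([[318]])
```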

@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="ctrl",
checkpoint="gpt2",
output_type=CausalLMOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
@@ -616,6 +690,8 @@ def forward(
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
@@ -648,6 +724,8 @@ def forward(
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
60 changes: 58 additions & 2 deletions tests/test_modeling_encoder_decoder.py
@@ -20,17 +20,17 @@
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
from .test_modeling_bert import BertModelTester
from .test_modeling_common import ids_tensor
from .test_modeling_gpt2 import GPT2ModelTester
from .test_modeling_roberta import RobertaModelTester


if is_torch_available():
from transformers import (
BertModel,
BertLMHeadModel,
GPT2LMHeadModel,
RobertaModel,
RobertaForCausalLM,
EncoderDecoderModel,
@@ -424,3 +424,59 @@ def prepare_config_and_inputs(self):

def get_pretrained_model(self):
return EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")


class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = BertModel(config)
decoder_model = GPT2LMHeadModel(decoder_config)
return encoder_model, decoder_model

def prepare_config_and_inputs(self):
model_tester_encoder = BertModelTester(self, batch_size=13)
model_tester_decoder = GPT2ModelTester(self, batch_size=13)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_input_mask,
decoder_head_mask,
decoder_token_type_ids,
decoder_sequence_labels,
decoder_token_labels,
decoder_choice_labels,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs

# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"input_ids": input_ids,
"attention_mask": input_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_token_type_ids": decoder_token_type_ids,
"decoder_attention_mask": decoder_input_mask,
"decoder_sequence_labels": decoder_sequence_labels,
"decoder_token_labels": decoder_token_labels,
"decoder_choice_labels": decoder_choice_labels,
"encoder_hidden_states": encoder_hidden_states,
"labels": decoder_token_labels,
}

def get_pretrained_model(self):
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")