diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index c731c01fa983..891a09ff1464 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -372,11 +372,16 @@ def generate(
 
         if self.config.is_encoder_decoder:
             if decoder_start_token_id is None:
-                decoder_start_token_id = bos_token_id
+                # see if BOS token can be used for decoder_start_token_id
+                if bos_token_id is not None:
+                    decoder_start_token_id = bos_token_id
+                elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
+                    decoder_start_token_id = self.config.decoder.bos_token_id
+                else:
+                    raise ValueError(
+                        "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
+                    )
 
-            assert (
-                decoder_start_token_id is not None
-            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
             assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
             assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
 
diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py
index 0ee5f2962b6d..664b4181f5b7 100644
--- a/src/transformers/modeling_encoder_decoder.py
+++ b/src/transformers/modeling_encoder_decoder.py
@@ -287,6 +287,8 @@ def forward(
             **kwargs_decoder,
         )
 
+        # TODO(PVP): currently it is not possible to use `past`
+        # with the encoder/decoder framework -> should be implemented
         return decoder_outputs + encoder_outputs
 
     def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
@@ -299,15 +301,24 @@ def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwarg
         encoder_outputs = (past,)
 
         decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
-
-        return {
+        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
+        input_dict = {
             "attention_mask": attention_mask,
-            "decoder_attention_mask": decoder_inputs["attention_mask"],
+            "decoder_attention_mask": decoder_attention_mask,
             "decoder_input_ids": decoder_inputs["input_ids"],
             "encoder_outputs": encoder_outputs,
         }
 
+        # Ideally all models should have a `use_cache`
+        # leave the following two ifs until all models have it implemented
+        if "use_cache" in decoder_inputs:
+            input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
+
+        if "past_key_values" in decoder_inputs:
+            input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"]
+
+        return input_dict
+
     def _reorder_cache(self, past, beam_idx):
-        # as a default encoder-decoder models do not re-order the past.
-        # TODO(PVP): might have to be updated, e.g. if GPT2 is to be used as a decoder
-        return past
+        # apply decoder cache reordering here
+        return self.decoder._reorder_cache(past, beam_idx)
diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py
index ea23a819d547..7f44f8a5ac3e 100644
--- a/src/transformers/modeling_gpt2.py
+++ b/src/transformers/modeling_gpt2.py
@@ -118,7 +118,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
 
 
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False):
+    def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
         super().__init__()
 
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
@@ -131,8 +131,12 @@ def __init__(self, nx, n_ctx, config, scale=False):
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
-
-        self.c_attn = Conv1D(n_state * 3, nx)
+        self.is_cross_attention = is_cross_attention
+        if self.is_cross_attention:
+            self.c_attn = Conv1D(2 * n_state, nx)
+            self.q_attn = Conv1D(n_state, nx)
+        else:
+            self.c_attn = Conv1D(3 * n_state, nx)
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
@@ -160,8 +164,11 @@ def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=
         if self.scale:
             w = w / (float(v.size(-1)) ** 0.5)
         nd, ns = w.size(-2), w.size(-1)
-        mask = self.bias[:, :, ns - nd : ns, :ns]
-        w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
+
+        if not self.is_cross_attention:
+            # only the "normal" attention layer implements the causal mask
+            mask = self.bias[:, :, ns - nd : ns, :ns]
+            w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
 
         if attention_mask is not None:
             # Apply the attention mask
@@ -193,10 +200,26 @@ def split_heads(self, x, k=False):
         return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
 
     def forward(
-        self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
+        self,
+        hidden_states,
+        layer_past=None,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        use_cache=False,
+        output_attentions=False,
     ):
-        x = self.c_attn(x)
-        query, key, value = x.split(self.split_size, dim=2)
+        if encoder_hidden_states is not None:
+            assert hasattr(
+                self, "q_attn"
+            ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
+            query = self.q_attn(hidden_states)
+            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+            attention_mask = encoder_attention_mask
+        else:
+            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
         query = self.split_heads(query)
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
@@ -239,32 +262,64 @@ def forward(self, x):
 class Block(nn.Module):
     def __init__(self, n_ctx, config, scale=False):
         super().__init__()
-        nx = config.n_embd
-        inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
-        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
-        self.attn = Attention(nx, n_ctx, config, scale)
-        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
+        hidden_size = config.n_embd
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = Attention(hidden_size, n_ctx, config, scale)
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        if config.add_cross_attention:
+            self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
+            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = MLP(inner_dim, config)
 
     def forward(
-        self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False,
+        self,
+        hidden_states,
+        layer_past=None,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        use_cache=False,
+        output_attentions=False,
     ):
-        output_attn = self.attn(
-            self.ln_1(x),
+        attn_outputs = self.attn(
+            self.ln_1(hidden_states),
             layer_past=layer_past,
             attention_mask=attention_mask,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
-        a = output_attn[0]  # output_attn: a, present, (attentions)
+        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
+        outputs = attn_outputs[1:]
+        # residual connection
+        hidden_states = attn_output + hidden_states
+
+        if encoder_hidden_states is not None:
+            # add one cross-attention block
+            assert hasattr(
+                self, "crossattention"
+            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+            cross_attn_outputs = self.crossattention(
+                self.ln_cross_attn(hidden_states),
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                output_attentions=output_attentions,
+            )
+            attn_output = cross_attn_outputs[0]
+            # residual connection
+            hidden_states = hidden_states + attn_output
+            outputs = outputs + cross_attn_outputs[1:]  # add cross attentions if we output attention weights
 
-        x = x + a
-        m = self.mlp(self.ln_2(x))
-        x = x + m
+        feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
+        # residual connection
+        hidden_states = hidden_states + feed_forward_hidden_states
 
-        outputs = [x] + output_attn[1:]
-        return outputs  # x, present, (attentions)
+        outputs = [hidden_states] + outputs
+        return outputs  # hidden_states, present, (cross_attentions, attentions)
 
 
 class GPT2PreTrainedModel(PreTrainedModel):
@@ -449,6 +504,8 @@ def forward(
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -506,7 +563,7 @@ def forward(
             # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
             # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            attention_mask = attention_mask[:, None, None, :]
 
             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
@@ -516,6 +573,17 @@ def forward(
             attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
             attention_mask = (1.0 - attention_mask) * -10000.0
 
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.add_cross_attention and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_attention_mask = None
+
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
@@ -546,6 +614,8 @@ def forward(
                 layer_past=layer_past,
                 attention_mask=attention_mask,
                 head_mask=head_mask[i],
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
                 use_cache=use_cache,
                 output_attentions=output_attentions,
             )
@@ -593,17 +663,21 @@ def __init__(self, config):
     def get_output_embeddings(self):
         return self.lm_head
 
-    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
-        return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past,
+            "use_cache": kwargs.get("use_cache"),
+        }
 
     @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="ctrl",
+        checkpoint="gpt2",
         output_type=CausalLMOutputWithPast,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -616,6 +690,8 @@ def forward(
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
         labels=None,
         use_cache=None,
         output_attentions=None,
@@ -648,6 +724,8 @@ def forward(
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py
index 5bc21ca0c40d..e56a04369bd9 100644
--- a/tests/test_modeling_encoder_decoder.py
+++ b/tests/test_modeling_encoder_decoder.py
@@ -20,10 +20,9 @@
 from transformers import is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 
-# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
-# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
 from .test_modeling_bert import BertModelTester
 from .test_modeling_common import ids_tensor
+from .test_modeling_gpt2 import GPT2ModelTester
 from .test_modeling_roberta import RobertaModelTester
 
 
@@ -31,6 +30,7 @@
     from transformers import (
         BertModel,
         BertLMHeadModel,
+        GPT2LMHeadModel,
         RobertaModel,
         RobertaForCausalLM,
         EncoderDecoderModel,
@@ -424,3 +424,59 @@ def prepare_config_and_inputs(self):
 
     def get_pretrained_model(self):
         return EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")
+
+
+class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
+    def get_encoder_decoder_model(self, config, decoder_config):
+        encoder_model = BertModel(config)
+        decoder_model = GPT2LMHeadModel(decoder_config)
+        return encoder_model, decoder_model
+
+    def prepare_config_and_inputs(self):
+        model_tester_encoder = BertModelTester(self, batch_size=13)
+        model_tester_decoder = GPT2ModelTester(self, batch_size=13)
+        encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
+        decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
+        (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        ) = encoder_config_and_inputs
+        (
+            decoder_config,
+            decoder_input_ids,
+            decoder_input_mask,
+            decoder_head_mask,
+            decoder_token_type_ids,
+            decoder_sequence_labels,
+            decoder_token_labels,
+            decoder_choice_labels,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        ) = decoder_config_and_inputs
+
+        # make sure that cross attention layers are added
+        decoder_config.add_cross_attention = True
+        # disable cache for now
+        decoder_config.use_cache = False
+        return {
+            "config": config,
+            "input_ids": input_ids,
+            "attention_mask": input_mask,
+            "decoder_config": decoder_config,
+            "decoder_input_ids": decoder_input_ids,
+            "decoder_token_type_ids": decoder_token_type_ids,
+            "decoder_attention_mask": decoder_input_mask,
+            "decoder_sequence_labels": decoder_sequence_labels,
+            "decoder_token_labels": decoder_token_labels,
+            "decoder_choice_labels": decoder_choice_labels,
+            "encoder_hidden_states": encoder_hidden_states,
+            "labels": decoder_token_labels,
+        }
+
+    def get_pretrained_model(self):
+        return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py
index 66e07f6d4a70..f483467a4c1e 100644
--- a/tests/test_modeling_gpt2.py
+++ b/tests/test_modeling_gpt2.py
@@ -20,7 +20,7 @@
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
-from .test_modeling_common import ModelTesterMixin, ids_tensor
+from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
 
 
 if is_torch_available():
@@ -62,27 +62,27 @@ def __init__(
         scope=None,
     ):
         self.parent = parent
-        self.batch_size = 14
-        self.seq_length = 7
-        self.is_training = True
-        self.use_token_type_ids = True
-        self.use_input_mask = True
-        self.use_labels = True
-        self.use_mc_token_ids = True
-        self.vocab_size = 99
-        self.hidden_size = 32
-        self.num_hidden_layers = 5
-        self.num_attention_heads = 4
-        self.intermediate_size = 37
-        self.hidden_act = "gelu"
-        self.hidden_dropout_prob = 0.1
-        self.attention_probs_dropout_prob = 0, 1
-        self.max_position_embeddings = 512
-        self.type_vocab_size = 16
-        self.type_sequence_label_size = 2
-        self.initializer_range = 0.02
-        self.num_labels = 3
-        self.num_choices = 4
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.is_training = is_training
+        self.use_token_type_ids = use_token_type_ids
+        self.use_input_mask = use_input_mask
+        self.use_labels = use_labels
+        self.use_mc_token_ids = use_mc_token_ids
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.type_sequence_label_size = type_sequence_label_size
+        self.initializer_range = initializer_range
+        self.num_labels = num_labels
+        self.num_choices = num_choices
         self.scope = None
         self.bos_token_id = vocab_size - 1
         self.eos_token_id = vocab_size - 1
@@ -142,6 +142,35 @@ def prepare_config_and_inputs(self):
             choice_labels,
         )
 
+    def prepare_config_and_inputs_for_decoder(self):
+        (
+            config,
+            input_ids,
+            input_mask,
+            head_mask,
+            token_type_ids,
+            mc_token_ids,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        ) = self.prepare_config_and_inputs()
+
+        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
+        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+
+        return (
+            config,
+            input_ids,
+            input_mask,
+            head_mask,
+            token_type_ids,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        )
+
     def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
         model = GPT2Model(config=config)
         model.to(torch_device)
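With these changes GPT-2 can act as the decoder of an EncoderDecoderModel, with BERT (or another encoder) feeding `encoder_hidden_states` into the new cross-attention layers. The sketch below is illustrative only and mirrors the setup of the new GPT2EncoderDecoderModelTest (cross-attention enabled on the decoder config, cache disabled); the checkpoint names, the example sentence, and the generation settings are assumptions, not part of the patch, and a warm-started model still needs fine-tuning before its output is meaningful.

# Minimal Bert2GPT2 sketch (illustrative only, not part of the patch).
from transformers import (
    BertModel,
    BertTokenizer,
    EncoderDecoderModel,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
)

# Mirror the test: build the GPT-2 decoder with cross-attention layers enabled
# and without the (not yet supported) encoder-decoder cache.
decoder_config = GPT2Config.from_pretrained("gpt2", add_cross_attention=True, use_cache=False)
decoder = GPT2LMHeadModel.from_pretrained("gpt2", config=decoder_config)
encoder = BertModel.from_pretrained("bert-base-cased")
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

bert_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = bert_tokenizer.encode("This is the article to be summarized.", return_tensors="pt")

# GPT-2 defines no decoder_start_token_id, so generate() now falls back to
# config.decoder.bos_token_id; passing it explicitly is equivalent.
output_ids = model.generate(
    input_ids,
    decoder_start_token_id=model.config.decoder.bos_token_id,
    max_length=32,
    use_cache=False,  # `past` is not yet usable with the encoder/decoder framework (see TODO above)
)
print(gpt2_tokenizer.decode(output_ids[0], skip_special_tokens=True))

If both checkpoints should be pulled in one call instead, the test's get_pretrained_model shows the shorter route, EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2"); depending on the library version you may still need to enable `add_cross_attention` on the decoder config yourself so that the GPT-2 blocks actually build their cross-attention weights.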