diff --git a/docs/source/model_doc/blenderbot.rst b/docs/source/model_doc/blenderbot.rst index 0a926b6250a8..4d79144e8e44 100644 --- a/docs/source/model_doc/blenderbot.rst +++ b/docs/source/model_doc/blenderbot.rst @@ -95,3 +95,12 @@ See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward` and `generate` .. autoclass:: transformers.BlenderbotForConditionalGeneration :members: + + +TFBlenderbotForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +See :obj:`transformers.TFBartForConditionalGeneration` for arguments to `call` and `generate` + +.. autoclass:: transformers.TFBlenderbotForConditionalGeneration + :members: diff --git a/docs/source/model_doc/marian.rst b/docs/source/model_doc/marian.rst index 8fc845d2f6a0..a5149ca7f0e2 100644 --- a/docs/source/model_doc/marian.rst +++ b/docs/source/model_doc/marian.rst @@ -129,3 +129,9 @@ MarianMTModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.MarianMTModel + + +TFMarianMTModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFMarianMTModel diff --git a/docs/source/model_doc/mbart.rst b/docs/source/model_doc/mbart.rst index 5df5645f8fc8..1d3a43968e41 100644 --- a/docs/source/model_doc/mbart.rst +++ b/docs/source/model_doc/mbart.rst @@ -79,4 +79,11 @@ MBartForConditionalGeneration ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.MBartForConditionalGeneration - :members: forward + :members: + + +TFMBartForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFMBartForConditionalGeneration + :members: diff --git a/docs/source/model_doc/pegasus.rst b/docs/source/model_doc/pegasus.rst index 4638e8470c44..5e885d553c7d 100644 --- a/docs/source/model_doc/pegasus.rst +++ b/docs/source/model_doc/pegasus.rst @@ -95,3 +95,9 @@ PegasusForConditionalGeneration ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.PegasusForConditionalGeneration + + +TFPegasusForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.TFPegasusForConditionalGeneration diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3ed23ff3001c..6662b2011e26 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -670,6 +670,7 @@ TFBertModel, TFBertPreTrainedModel, ) + from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration from .modeling_tf_camembert import ( TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TFCamembertForMaskedLM, @@ -750,6 +751,8 @@ TFLxmertPreTrainedModel, TFLxmertVisualFeatureEncoder, ) + from .modeling_tf_marian import TFMarianMTModel + from .modeling_tf_mbart import TFMBartForConditionalGeneration from .modeling_tf_mobilebert import ( TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TFMobileBertForMaskedLM, @@ -771,6 +774,7 @@ TFOpenAIGPTModel, TFOpenAIGPTPreTrainedModel, ) + from .modeling_tf_pegasus import TFPegasusForConditionalGeneration from .modeling_tf_roberta import ( TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, TFRobertaForMaskedLM, diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 78a840cea557..da34f70072c4 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -427,7 +427,6 @@ def forward( output_attentions=False, ): residual = x - if layer_state is None: layer_state = {} if self.normalize_before: @@ -447,7 +446,7 @@ def forward( if not self.normalize_before: x = self.self_attn_layer_norm(x) - # Cross attention + # Cross-Attention Block residual = x assert self.encoder_attn.cache_key != self.self_attn.cache_key if self.normalize_before: @@ -628,7 +627,6 @@ def forward( encoder_hidden_states = encoder_hidden_states.transpose(0, 1) next_cache = next_decoder_cache if use_cache else None - if not return_dict: return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( diff --git a/src/transformers/modeling_tf_auto.py b/src/transformers/modeling_tf_auto.py index c90028125c2b..e8a51454d1c1 100644 --- a/src/transformers/modeling_tf_auto.py +++ b/src/transformers/modeling_tf_auto.py @@ -41,6 +41,10 @@ XLNetConfig, replace_list_option_in_docstrings, ) +from .configuration_blenderbot import BlenderbotConfig +from .configuration_marian import MarianConfig +from .configuration_mbart import MBartConfig +from .configuration_pegasus import PegasusConfig from .configuration_utils import PretrainedConfig from .file_utils import add_start_docstrings from .modeling_tf_albert import ( @@ -63,6 +67,7 @@ TFBertLMHeadModel, TFBertModel, ) +from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration from .modeling_tf_camembert import ( TFCamembertForMaskedLM, TFCamembertForMultipleChoice, @@ -108,6 +113,8 @@ ) from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel +from .modeling_tf_marian import TFMarianMTModel +from .modeling_tf_mbart import TFMBartForConditionalGeneration from .modeling_tf_mobilebert import ( TFMobileBertForMaskedLM, TFMobileBertForMultipleChoice, @@ -118,6 +125,7 @@ TFMobileBertModel, ) from .modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel +from .modeling_tf_pegasus import TFPegasusForConditionalGeneration from .modeling_tf_roberta import ( TFRobertaForMaskedLM, TFRobertaForMultipleChoice, @@ -210,6 +218,7 @@ (T5Config, TFT5ForConditionalGeneration), (DistilBertConfig, TFDistilBertForMaskedLM), (AlbertConfig, TFAlbertForMaskedLM), + 
(MarianConfig, TFMarianMTModel), (BartConfig, TFBartForConditionalGeneration), (CamembertConfig, TFCamembertForMaskedLM), (XLMRobertaConfig, TFXLMRobertaForMaskedLM), @@ -261,8 +270,16 @@ ] ) + TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( - [(T5Config, TFT5ForConditionalGeneration), (BartConfig, TFBartForConditionalGeneration)] + [ + (T5Config, TFT5ForConditionalGeneration), + (MarianConfig, TFMarianMTModel), + (MBartConfig, TFMBartForConditionalGeneration), + (PegasusConfig, TFPegasusForConditionalGeneration), + (BlenderbotConfig, TFBlenderbotForConditionalGeneration), + (BartConfig, TFBartForConditionalGeneration), + ] ) TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( diff --git a/src/transformers/modeling_tf_bart.py b/src/transformers/modeling_tf_bart.py index a2e05cf971df..6814520a6342 100644 --- a/src/transformers/modeling_tf_bart.py +++ b/src/transformers/modeling_tf_bart.py @@ -19,9 +19,10 @@ import warnings from typing import Dict, Optional, Tuple +import numpy as np import tensorflow as tf from tensorflow import Tensor -from tensorflow.keras.layers import Dense, LayerNormalization +from tensorflow.keras.layers import Dense, Layer, LayerNormalization from .activations_tf import ACT2FN from .configuration_bart import BartConfig @@ -43,7 +44,6 @@ _CONFIG_FOR_DOC = "BartConfig" -_TOKENIZER_FOR_DOC = "BartTokenizer" BART_START_DOCSTRING = r""" @@ -218,22 +218,21 @@ def make_padding_mask(input_ids, padding_idx=1): ) -class TFEncoderLayer(tf.keras.layers.Layer): +class TFEncoderLayer(Layer): def __init__(self, config: BartConfig, **kwargs): super().__init__(**kwargs) self.embed_dim = config.d_model self.self_attn = TFAttention( self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" ) - - self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.dropout_wt = tf.keras.layers.Dropout(config.dropout) + self.normalize_before = config.normalize_before + self.self_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout) + self.activation_dropout = config.activation_dropout self.fc1 = Dense(config.encoder_ffn_dim, name="fc1") self.fc2 = Dense(self.embed_dim, name="fc2") self.final_layer_norm = LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.normalize_before = config.normalize_before def call(self, x, encoder_padding_mask, training=False): """ @@ -251,8 +250,10 @@ def call(self, x, encoder_padding_mask, training=False): if self.normalize_before: x = self.self_attn_layer_norm(x) x, self_attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask) - assert x.shape == residual.shape, f"Self attn modified the shape of query {residual.shape} to {x.shape}" - x = self.dropout_wt(x, training=training) + assert shape_list(x) == shape_list( + residual + ), f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}" + x = tf.nn.dropout(x, rate=self.dropout if training else 0) x = residual + x if not self.normalize_before: x = self.self_attn_layer_norm(x) @@ -261,9 +262,9 @@ def call(self, x, encoder_padding_mask, training=False): if self.normalize_before: x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) - x = self.activation_dropout(x, training=training) + x = tf.nn.dropout(x, rate=self.activation_dropout if training
else 0) x = self.fc2(x) - x = self.dropout_wt(x, training=training) + x = tf.nn.dropout(x, rate=self.dropout if training else 0) x = residual + x if not self.normalize_before: x = self.final_layer_norm(x) @@ -271,7 +272,7 @@ def call(self, x, encoder_padding_mask, training=False): return x, self_attn_weights -class TFBartEncoder(tf.keras.layers.Layer): +class TFBartEncoder(Layer): # config_class = BartConfig """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a @@ -289,26 +290,30 @@ def __init__(self, config: BartConfig, embed_tokens: TFSharedEmbeddings, **kwarg self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions - embed_dim = embed_tokens.vocab_size - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings self.embed_tokens = embed_tokens - self.embed_positions = TFLearnedPositionalEmbedding( - config.max_position_embeddings, - embed_tokens.hidden_size, - self.padding_idx, - config.extra_pos_embeddings, - name="embed_positions", - ) + if config.static_position_embeddings: + self.embed_positions = TFSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + name="embed_positions", + ) + else: + self.embed_positions = TFLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + self.padding_idx, + config.extra_pos_embeddings, + name="embed_positions", + ) self.layers = [TFEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] - self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") - self.layer_norm = ( - tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - if config.add_final_layer_norm - else None + self.layernorm_embedding = ( + LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding else Layer() ) + self.layer_norm = LayerNormalization(epsilon=1e-5, name="layer_norm") if config.add_final_layer_norm else None self.return_dict = config.return_dict def call( @@ -347,7 +352,7 @@ def call( ), f"expected attention_mask._rank() to be a 2D tensor got {attention_mask._rank()}" attention_mask = tf.cast(attention_mask, dtype=tf.float32) attention_mask = (1.0 - attention_mask) * LARGE_NEGATIVE - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale embed_pos = self.embed_positions(input_ids) x = inputs_embeds + embed_pos x = self.layernorm_embedding(x) @@ -384,7 +389,7 @@ def call( return TFBaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions) -class TFDecoderLayer(tf.keras.layers.Layer): +class TFDecoderLayer(Layer): def __init__(self, config: BartConfig, **kwargs): super().__init__(**kwargs) self.embed_dim = config.d_model @@ -397,8 +402,9 @@ def __init__(self, config: BartConfig, **kwargs): self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout + self.normalize_before = config.normalize_before - self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.self_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") self.encoder_attn = TFAttention( self.embed_dim, 
config.decoder_attention_heads, @@ -406,10 +412,10 @@ def __init__(self, config: BartConfig, **kwargs): encoder_decoder_attention=True, name="encoder_attn", ) - self.encoder_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") + self.encoder_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") self.fc1 = Dense(config.decoder_ffn_dim, name="fc1") self.fc2 = Dense(self.embed_dim, name="fc2") - self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.final_layer_norm = LayerNormalization(epsilon=1e-5, name="final_layer_norm") def call( self, @@ -433,10 +439,12 @@ def call( Tuple containing, encoded output of shape `(seq_len, batch, embed_dim)`, self_attn_weights, layer_state """ + residual = x # Make a copy of the input tensor to add later. if layer_state is None: layer_state = {} + if self.normalize_before: + x = self.self_attn_layer_norm(x) - residual = x # Make a copy of the input tensor to add later. # next line mutates layer state and we need a copy of it x, self_attn_weights = self.self_attn( query=x, @@ -447,9 +455,12 @@ def call( ) x = tf.nn.dropout(x, rate=self.dropout if training else 0) x = residual + x - x = self.self_attn_layer_norm(x) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + # Cross-Attention Block residual = x - # Cross-Attention + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) x, _ = self.encoder_attn( query=x, key=encoder_hidden_states, @@ -458,16 +469,19 @@ def call( ) x = tf.nn.dropout(x, rate=self.dropout if training else 0) x = residual + x - - x = self.encoder_attn_layer_norm(x) - + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + # Fully Connected residual = x + if self.normalize_before: + x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) x = tf.nn.dropout(x, rate=self.activation_dropout if training else 0) x = self.fc2(x) x = tf.nn.dropout(x, rate=self.dropout if training else 0) x = residual + x - x = self.final_layer_norm(x) + if not self.normalize_before: + x = self.final_layer_norm(x) return ( x, self_attn_weights, @@ -475,7 +489,7 @@ def call( ) # just self_attn weights for now, following t5, layer_state = cache for decoding -class TFBartDecoder(tf.keras.layers.Layer): +class TFBartDecoder(Layer): """ Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a :class:`TFDecoderLayer` @@ -491,26 +505,27 @@ def __init__(self, config: BartConfig, embed_tokens, **kwargs): self.max_target_positions = config.max_position_embeddings self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - self.embed_positions = TFLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - self.padding_idx, - config.extra_pos_embeddings, - name="embed_positions", - ) + if config.static_position_embeddings: + self.embed_positions = TFSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + name="embed_positions", + ) + else: + self.embed_positions = TFLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + self.padding_idx, + config.extra_pos_embeddings, + name="embed_positions", + ) self.layers = [TFDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] self.layernorm_embedding = ( - tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") - if config.normalize_embedding - else tf.identity - ) - self.layer_norm = ( - tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - if config.add_final_layer_norm - else None + LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding else Layer() ) + self.layer_norm = LayerNormalization(epsilon=1e-5, name="layer_norm") if config.add_final_layer_norm else None - self.dropout = tf.keras.layers.Dropout(config.dropout) + self.dropout = config.dropout self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.use_cache = config.use_cache @@ -553,11 +568,11 @@ def call( x = self.layernorm_embedding(x) + positions else: x = self.layernorm_embedding(x + positions) - x = self.dropout(x) + x = tf.nn.dropout(x, rate=self.dropout if training else 0) # Convert to Bart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) x = tf.transpose(x, perm=(1, 0, 2)) - assert len(encoder_hidden_states.shape) == 3, "encoder_hidden_states must be a 3D tensor" + assert len(shape_list(encoder_hidden_states)) == 3, "encoder_hidden_states must be a 3D tensor" encoder_hidden_states = tf.transpose(encoder_hidden_states, perm=(1, 0, 2)) # decoder layers @@ -623,7 +638,7 @@ def _reorder_buffer(attn_cache, new_order): return attn_cache -class TFAttention(tf.keras.layers.Layer): +class TFAttention(Layer): """Multi-headed attention from "Attention Is All You Need""" def __init__( @@ -678,8 +693,10 @@ def call( (default: None). 
""" static_kv = self.encoder_decoder_attention # value=key=encoder_hidden_states, - tgt_len, bsz, embed_dim = query.shape - assert embed_dim == self.embed_dim, f"query must be shaped {(tgt_len, bsz, self.embed_dim)} got {query.shape}" + tgt_len, bsz, embed_dim = shape_list(query) + assert ( + embed_dim == self.embed_dim + ), f"query must be shaped {(tgt_len, bsz, self.embed_dim)} got {shape_list(query)}" # get here for encoder decoder cause of static_kv if layer_state is not None: # get the last k and v for reuse saved_state = layer_state.get(self.cache_key, {}) @@ -718,7 +735,7 @@ def call( ) # Compute multi-headed attention - src_len = k.shape[1] + src_len = shape_list(k)[1] attn_weights = tf.matmul(q, k, transpose_b=True) # shape (bsz * self.num_heads, tgt_len, src_len) if attn_mask is not None: @@ -770,7 +787,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, of def call(self, input_ids: tf.Tensor, use_cache=False): """Input is expected to be of size [bsz x seqlen].""" - bsz, seq_len = input_ids.shape[:2] + bsz, seq_len = shape_list(input_ids)[:2] if use_cache: positions = tf.fill((1, 1), seq_len - 1) @@ -780,6 +797,56 @@ def call(self, input_ids: tf.Tensor, use_cache=False): return super().call(positions + self.offset) # super object is not callable for some reason +class TFSinusoidalPositionalEmbedding(tf.keras.layers.Embedding): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions, embedding_dim, **kwargs): + + if embedding_dim % 2 != 0: + raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") + super().__init__( + num_positions, + embedding_dim, + **kwargs, + ) + + def build(self, input_shape): + """ + Build shared token embedding layer Shared weights logic adapted from + https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 + """ + super().build(input_shape) # Instantiates self.weight so it can be loaded + weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim) + self.set_weights([weight]) # overwrite self.weight to correct value + + @staticmethod + def _init_weight(n_pos, dim): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. 
[dim // 2:] """ + position_enc = np.array( + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] + ) + # the pos = 0 row is all zeros before the sin/cos transform + position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2]) + position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2]) + # convert to tensor + table = tf.convert_to_tensor(position_enc, dtype=tf.float32) + table = tf.stop_gradient(table) + return table + + def call(self, input_ids, use_cache=False): + """Input is expected to be of size [bsz x seqlen].""" + bsz, seq_len = shape_list(input_ids)[:2] + if use_cache: + positions = tf.fill((1, 1), seq_len - 1) + else: + # starts at 0, ends at seq_len - 1 + positions = tf.range(0, seq_len, delta=1, dtype=tf.int32, name="range") + return super().call(positions) + + # Public API @@ -818,7 +885,7 @@ def _prepare_bart_decoder_inputs( pad_token_id = self.config.pad_token_id if decoder_input_ids is None: decoder_input_ids = self._shift_right(inputs) - bsz, tgt_len = decoder_input_ids.shape[:2] + bsz, tgt_len = shape_list(decoder_input_ids)[:2] if decoder_attn_mask is None: decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id) else: @@ -950,16 +1017,20 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): base_model_prefix = "model" authorized_missing_keys = [ r"final_logits_bias", - r"encoder\.version", - r"decoder\.version", - "model.encoder.embed_tokens.weight", - "model.decoder.embed_tokens.weight", + ] + authorized_unexpected_keys = [ + r"model.encoder.embed_tokens.weight", + r"model.decoder.embed_tokens.weight", + ] def __init__(self, config: BartConfig, *args, **kwargs): super().__init__(config, *args, **kwargs) self.model = TFBartModel(config, name="model") self.use_cache = config.use_cache + # final_logits_bias is registered as a buffer in pytorch, so it is kept non-trainable here for the sake of consistency. + self.final_logits_bias = self.add_weight( + name="/final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False + ) @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @@ -1050,6 +1121,7 @@ def call( return_dict=True, # TODO(SS): this may need to change to support compilation ) logits = self.model.shared(outputs.last_hidden_state, mode="linear") + logits = logits + self.final_logits_bias loss = None if labels is None else self.compute_loss(labels, logits) past = outputs.past_key_values if cast_bool_to_primitive(use_cache, self.config.use_cache) else None @@ -1096,7 +1168,7 @@ def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, ), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past" assert isinstance( encoder_outputs, TFBaseModelOutput - ), "encoder_outputs should be a TFBaseModelOutput, Instead got " + ), f"encoder_outputs should be a TFBaseModelOutput, instead got {type(encoder_outputs)}." return { "inputs": None, # encoder_outputs is defined. 
input_ids not needed "encoder_outputs": encoder_outputs, @@ -1113,7 +1185,6 @@ def _reorder_cache(past, beam_idx): reordered_past = [] for layer_past in decoder_cached_states: # get the correct batch idx from decoder layer's batch dim for cross and self-attn - layer_past_new = { attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() } @@ -1124,26 +1195,13 @@ def _reorder_cache(past, beam_idx): def adjust_logits_during_generation(self, logits, cur_len, max_length): if cur_len == 1 and self.config.force_bos_token_to_be_generated: - logits = self._force_token_id_to_be_generated(logits, self.config.bos_token_id) - elif cur_len == max_length - 1 and self.config.eos_token_id is not None: - logits = self._force_token_id_to_be_generated(logits, self.config.eos_token_id) - return logits - - @staticmethod - def _force_token_id_to_be_generated(scores, token_id) -> None: - """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" - output_list = [] - - # Is there a better way to do scores[:, [x for if x != token_id]] = -float("inf") in TF? - bs, vocab_size = scores.shape - for x in range(vocab_size): - if x != token_id: - output_list.append(tf.convert_to_tensor([-float("inf")] * bs, dtype=scores.dtype)) - else: - output_list.append(scores[:, x]) - scores = tf.stack(output_list, axis=1, name="scores") - assert scores.shape == (bs, vocab_size) - return scores + vocab_range = tf.constant(range(self.config.vocab_size)) + return tf.where(vocab_range != self.config.bos_token_id, LARGE_NEGATIVE, logits) + elif cur_len == max_length - 1: + vocab_range = tf.constant(range(self.config.vocab_size)) + return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) + else: + return logits def get_output_embeddings(self): return self.model.shared diff --git a/src/transformers/modeling_tf_blenderbot.py b/src/transformers/modeling_tf_blenderbot.py new file mode 100644 index 000000000000..633b50ec7757 --- /dev/null +++ b/src/transformers/modeling_tf_blenderbot.py @@ -0,0 +1,47 @@ +# coding=utf-8 +# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
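+# A quick sketch of the tf.where-based token forcing used in adjust_logits_during_generation below, with toy values (hypothetical vocab_size=5, forced token id 3): +#     vocab_range = tf.constant(range(5))  # [0, 1, 2, 3, 4] +#     logits = tf.where(vocab_range != 3, LARGE_NEGATIVE, logits) +# every entry except index 3 becomes LARGE_NEGATIVE, so softmax assigns it ~zero probability and token 3 is effectively forced.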
+"""TF BlenderBot model, ported from the fairseq repo.""" +from .configuration_blenderbot import BlenderbotConfig +from .file_utils import add_start_docstrings, is_tf_available +from .modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration +from .utils import logging + + +if is_tf_available(): + import tensorflow as tf + + +_CONFIG_FOR_DOC = "BlenderbotConfig" + +START_DOCSTRING = BART_START_DOCSTRING.replace( + "inherits from :class:`~transformers.TFPreTrainedModel`", + "inherits from :class:`~transformers.TFBartForConditionalGeneration`", +).replace("BartConfig", _CONFIG_FOR_DOC) + + +logger = logging.get_logger(__name__) + + +@add_start_docstrings("Blenderbot model for open domain dialogue", START_DOCSTRING) +class TFBlenderbotForConditionalGeneration(TFBartForConditionalGeneration): + config_class = BlenderbotConfig + + def adjust_logits_during_generation(self, logits, cur_len, max_length): + """Never predict pad_token_id. Predict when max_length is reached.""" + vocab_range = tf.constant(range(self.config.vocab_size)) + logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits) + if cur_len == max_length - 1: + logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) + return logits diff --git a/src/transformers/modeling_tf_marian.py b/src/transformers/modeling_tf_marian.py new file mode 100644 index 000000000000..9dcd5489660a --- /dev/null +++ b/src/transformers/modeling_tf_marian.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF Marian model, ported from the fairseq repo.""" + +from .configuration_marian import MarianConfig +from .file_utils import add_start_docstrings, is_tf_available +from .modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration +from .utils import logging + + +if is_tf_available(): + import tensorflow as tf + + +_CONFIG_FOR_DOC = "MarianConfig" + +START_DOCSTRING = BART_START_DOCSTRING.replace( + "inherits from :class:`~transformers.TFPreTrainedModel`", + "inherits from :class:`~transformers.TFBartForConditionalGeneration`", +).replace("BartConfig", _CONFIG_FOR_DOC) + + +logger = logging.get_logger(__name__) + + +@add_start_docstrings("Marian model for machine translation", START_DOCSTRING) +class TFMarianMTModel(TFBartForConditionalGeneration): + authorized_missing_keys = [ + r"model.encoder.embed_positions.weight", + r"model.decoder.embed_positions.weight", + ] + config_class = MarianConfig + + def adjust_logits_during_generation(self, logits, cur_len, max_length): + """Never predict pad_token_id. 
Force eos_token_id to be generated when max_length is reached.""" + vocab_range = tf.constant(range(self.config.vocab_size)) + logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits) + if cur_len == max_length - 1: + logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) + return logits diff --git a/src/transformers/modeling_tf_mbart.py b/src/transformers/modeling_tf_mbart.py new file mode 100644 index 000000000000..804324a31631 --- /dev/null +++ b/src/transformers/modeling_tf_mbart.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF mBART model, originally from fairseq.""" +from .configuration_mbart import MBartConfig +from .file_utils import add_start_docstrings +from .modeling_tf_bart import BART_START_DOCSTRING, TFBartForConditionalGeneration +from .utils import logging + + +_CONFIG_FOR_DOC = "MBartConfig" + +START_DOCSTRING = BART_START_DOCSTRING.replace( + "inherits from :class:`~transformers.TFPreTrainedModel`", + "inherits from :class:`~transformers.TFBartForConditionalGeneration`", +).replace("BartConfig", _CONFIG_FOR_DOC) + + +logger = logging.get_logger(__name__) + + +@add_start_docstrings("mBART (multilingual BART) model for machine translation", START_DOCSTRING) +class TFMBartForConditionalGeneration(TFBartForConditionalGeneration): + config_class = MBartConfig + # All the code is in src/transformers/modeling_tf_bart.py diff --git a/src/transformers/modeling_tf_pegasus.py b/src/transformers/modeling_tf_pegasus.py new file mode 100644 index 000000000000..262c7bdb28c3 --- /dev/null +++ b/src/transformers/modeling_tf_pegasus.py @@ -0,0 +1,41 @@ +# coding=utf-8 +# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
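+# Note: as with TFMarianMTModel, the position-embedding weights appear under +# authorized_missing_keys below because TFSinusoidalPositionalEmbedding rebuilds its +# table deterministically in build(), so checkpoints do not need to store it.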
+"""TF Pegasus model, ported from the fairseq repo.""" +from .configuration_pegasus import PegasusConfig +from .file_utils import add_start_docstrings +from .modeling_tf_bart import BART_START_DOCSTRING, TFBartForConditionalGeneration +from .utils import logging + + +_CONFIG_FOR_DOC = "PegasusConfig" + +START_DOCSTRING = BART_START_DOCSTRING.replace( + "inherits from :class:`~transformers.TFPreTrainedModel`", + "inherits from :class:`~transformers.TFBartForConditionalGeneration`", +).replace("BartConfig", _CONFIG_FOR_DOC) + + +logger = logging.get_logger(__name__) + + +@add_start_docstrings("Pegasus model for summarization", START_DOCSTRING) +class TFPegasusForConditionalGeneration(TFBartForConditionalGeneration): + authorized_missing_keys = [ + r"final_logits_bias", + r"model.encoder.embed_positions.weight", + r"model.decoder.embed_positions.weight", + ] + config_class = PegasusConfig + # All the code is in src/transformers/modeling_tf_bart.py diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index fb1268a51ab7..21a85e436346 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -325,6 +325,15 @@ def from_pretrained(self, *args, **kwargs): requires_tf(self) +class TFBlenderbotForConditionalGeneration: + def __init__(self, *args, **kwargs): + requires_tf(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_tf(self) + + TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -797,6 +806,24 @@ def __init__(self, *args, **kwargs): requires_tf(self) +class TFMarianMTModel: + def __init__(self, *args, **kwargs): + requires_tf(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_tf(self) + + +class TFMBartForConditionalGeneration: + def __init__(self, *args, **kwargs): + requires_tf(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_tf(self) + + TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -922,6 +949,15 @@ def from_pretrained(self, *args, **kwargs): requires_tf(self) +class TFPegasusForConditionalGeneration: + def __init__(self, *args, **kwargs): + requires_tf(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_tf(self) + + TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index 3859f4348248..f19d4365e16a 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -137,7 +137,7 @@ def translate_src_text(self, **tokenizer_kwargs): ) self.assertEqual(self.model.device, model_inputs.input_ids.device) generated_ids = self.model.generate( - model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2 + model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128 ) generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) return generated_words @@ -243,6 +243,8 @@ def test_batch_generation_ru_fr(self): @require_sentencepiece @require_tokenizers class TestMarian_MT_EN(MarianIntegrationTest): + """Cover low resource/high perplexity setting. 
This breaks if adjust_logits_during_generation is not overridden.""" + src = "mt" tgt = "en" src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."] diff --git a/tests/test_modeling_tf_bart.py b/tests/test_modeling_tf_bart.py index ff49574b264c..4efdd3b08b09 100644 --- a/tests/test_modeling_tf_bart.py +++ b/tests/test_modeling_tf_bart.py @@ -17,7 +17,9 @@ import tempfile import unittest -from transformers import is_tf_available +import numpy as np + +from transformers import BartConfig, BartTokenizer, is_tf_available from transformers.file_utils import cached_property from transformers.testing_utils import is_pt_tf_cross_test, require_tf, slow @@ -28,12 +30,16 @@ if is_tf_available(): import tensorflow as tf - from transformers import BartConfig, TFBartForConditionalGeneration, TFBartModel - from transformers.tokenization_bart import BartTokenizer + from transformers import TFBartForConditionalGeneration, TFBartModel + from transformers.modeling_tf_bart import TFSinusoidalPositionalEmbedding @require_tf -class ModelTester: +class TFBartModelTester: + config_cls = BartConfig + config_updates = {} + hidden_act = "gelu" + def __init__(self, parent): self.parent = parent self.batch_size = 13 @@ -45,14 +51,13 @@ def __init__(self, parent): self.num_hidden_layers = 5 self.num_attention_heads = 4 self.intermediate_size = 37 - self.hidden_act = "gelu" + self.hidden_dropout_prob = 0.1 self.attention_probs_dropout_prob = 0.1 self.max_position_embeddings = 20 self.eos_token_ids = [2] self.pad_token_id = 1 self.bos_token_id = 0 - # torch.manual_seed(0) def prepare_config_and_inputs_for_common(self): input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size) @@ -60,7 +65,7 @@ def prepare_config_and_inputs_for_common(self): input_ids = tf.concat([input_ids, eos_tensor], axis=1) input_ids = tf.clip_by_value(input_ids, 3, self.vocab_size + 1) - config = BartConfig( + config = self.config_cls( vocab_size=self.vocab_size, d_model=self.hidden_size, encoder_layers=self.num_hidden_layers, @@ -76,6 +81,7 @@ def prepare_config_and_inputs_for_common(self): bos_token_id=self.bos_token_id, pad_token_id=self.pad_token_id, decoder_start_token_id=self.pad_token_id, + **self.config_updates, ) inputs_dict = prepare_bart_inputs_dict(config, input_ids) return config, inputs_dict @@ -101,9 +107,10 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + model_tester_cls = TFBartModelTester def setUp(self): - self.model_tester = ModelTester(self) + self.model_tester = self.model_tester_cls(self) self.config_tester = ConfigTester(self, config_class=BartConfig) def test_config(self): @@ -120,7 +127,7 @@ def test_compile_tf_model(self): loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy") - model_class = TFBartForConditionalGeneration + model_class = self.all_generative_model_classes[0] input_ids = { "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"), "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"), } @@ -354,3 +361,29 @@ def test_encoder_equiv(self): expected = np.array([[-0.0828, -0.0251, -0.0674], [0.1277, 0.3311, -0.0255], [0.2613, -0.0840, -0.2763]]) assert np.allclose(features[0, :3, :3].numpy(), expected, atol=1e-3) + + +@require_tf +class 
TestTFSinusoidalPositionalEmbeddings(unittest.TestCase): + desired_weights = [ + [0, 0, 0, 0, 0], + [0.84147096, 0.82177866, 0.80180490, 0.78165019, 0.76140374], + [0.90929741, 0.93651021, 0.95829457, 0.97505713, 0.98720258], + ] + + def test_positional_emb_cache_logic(self): + input_ids = _long_tensor([[4, 10]]) + emb1 = TFSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6) + no_cache = emb1(input_ids, use_cache=False) + yes_cache = emb1(input_ids, use_cache=True) + self.assertEqual((1, 1, 6), yes_cache.shape) # extra dim to allow broadcasting, feel free to delete! + + np.testing.assert_almost_equal(no_cache[-1].numpy(), yes_cache[0][0].numpy()) + + def test_positional_emb_weights_against_marian(self): + emb1 = TFSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512) + emb1.build(None) + weights = emb1.embeddings.numpy() + for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)): + for j in range(5): + self.assertAlmostEqual(expected_weight[j], actual_weight[j], places=3) diff --git a/tests/test_modeling_tf_blenderbot.py b/tests/test_modeling_tf_blenderbot.py new file mode 100644 index 000000000000..df11567e41a8 --- /dev/null +++ b/tests/test_modeling_tf_blenderbot.py @@ -0,0 +1,132 @@ +# coding=utf-8 +# Copyright 2020 HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
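+# The ModelTester below reuses TFBartModelTester and only swaps the config: flags such +# as normalize_before and static_position_embeddings route the shared TFBart layers +# through the pre-LayerNorm and sinusoidal-position-embedding branches added in +# modeling_tf_bart.py, so the common tests cover those new code paths.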
+import tempfile +import unittest + +from tests.test_configuration_common import ConfigTester +from tests.test_modeling_tf_bart import TFBartModelTester +from tests.test_modeling_tf_common import TFModelTesterMixin +from transformers import BlenderbotConfig, BlenderbotSmallTokenizer, is_tf_available +from transformers.file_utils import cached_property +from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow + + +if is_tf_available(): + import tensorflow as tf + + from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration + + +class ModelTester(TFBartModelTester): + config_updates = dict( + normalize_before=True, + static_position_embeddings=True, + do_blenderbot_90_layernorm=True, + normalize_embeddings=True, + ) + config_cls = BlenderbotConfig + + +@require_tf +class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase): + all_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else () + all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else () + model_tester_cls = ModelTester + is_encoder_decoder = True + test_pruning = False + + def setUp(self): + self.model_tester = self.model_tester_cls(self) + self.config_tester = ConfigTester(self, config_class=BlenderbotConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_inputs_embeds(self): + # inputs_embeds not supported + pass + + def test_saved_model_with_hidden_states_output(self): + # Should be uncommented during patrick TF refactor + pass + + def test_saved_model_with_attentions_output(self): + # Should be uncommented during patrick TF refactor + pass + + def test_compile_tf_model(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0) + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy") + + model_class = self.all_generative_model_classes[0] + input_ids = { + "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"), + "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"), + } + + # Prepare our model + model = model_class(config) + model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving. + # Let's load it from the disk to be sure we can use pretrained weights + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname) + + outputs_dict = model(input_ids) + hidden_states = outputs_dict[0] + + # Add a dense layer on top to test integration with other keras modules + outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states) + + # Compile extended model + extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs]) + extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) + + +@is_pt_tf_cross_test +@require_tokenizers +class TFBlenderbot90MIntegrationTests(unittest.TestCase): + src_text = [ + "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?" 
+ ] + model_name = "facebook/blenderbot-90M" + + @cached_property + def tokenizer(self): + return BlenderbotSmallTokenizer.from_pretrained(self.model_name) + + @cached_property + def model(self): + model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True) + return model + + @slow + def test_90_generation_from_long_input(self): + model_inputs = self.tokenizer(self.src_text, return_tensors="tf") + generated_ids = self.model.generate( + model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + num_beams=2, + use_cache=True, + ) + generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True)[0] + assert generated_words in ( + "i don't know. i just feel like i'm going to throw up. it's not fun.", + "i'm not sure. i just feel like i've been feeling like i have to be in a certain place", + "i'm not sure. i just feel like i've been in a bad situation.", + ) diff --git a/tests/test_modeling_tf_marian.py b/tests/test_modeling_tf_marian.py new file mode 100644 index 000000000000..a713023d4f1f --- /dev/null +++ b/tests/test_modeling_tf_marian.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2020 HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +import warnings + +from transformers import AutoTokenizer, MarianConfig, MarianTokenizer, TranslationPipeline, is_tf_available +from transformers.file_utils import cached_property +from transformers.testing_utils import is_pt_tf_cross_test, require_sentencepiece, require_tf, require_tokenizers, slow + +from .test_configuration_common import ConfigTester +from .test_modeling_tf_bart import TFBartModelTester +from .test_modeling_tf_common import TFModelTesterMixin + + +if is_tf_available(): + import tensorflow as tf + + from transformers import TFAutoModelForSeq2SeqLM, TFMarianMTModel + + +class ModelTester(TFBartModelTester): + config_updates = dict(static_position_embeddings=True, add_bias_logits=True) + config_cls = MarianConfig + + +@require_tf +class TestTFMarianCommon(TFModelTesterMixin, unittest.TestCase): + all_model_classes = (TFMarianMTModel,) if is_tf_available() else () + all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else () + model_tester_cls = ModelTester + is_encoder_decoder = True + test_pruning = False + + def setUp(self): + self.model_tester = self.model_tester_cls(self) + self.config_tester = ConfigTester(self, config_class=MarianConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_inputs_embeds(self): + # inputs_embeds not supported + pass + + def test_saved_model_with_hidden_states_output(self): + # Should be uncommented during patrick TF refactor + pass + + def test_saved_model_with_attentions_output(self): + pass + + def test_compile_tf_model(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0) + loss = 
tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy") + + model_class = self.all_generative_model_classes[0] + input_ids = { + "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"), + "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"), + } + + # Prepare our model + model = model_class(config) + model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving. + # Let's load it from the disk to be sure we can use pre-trained weights + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname) + + outputs_dict = model(input_ids) + hidden_states = outputs_dict[0] + + # Add a dense layer on top to test integration with other keras modules + outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states) + + # Compile extended model + extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs]) + extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) + + +class AbstractMarianIntegrationTest(unittest.TestCase): + maxDiff = 1000 # show more chars for failing integration tests + + @classmethod + def setUpClass(cls) -> None: + cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}" + return cls + + @cached_property + def tokenizer(self) -> MarianTokenizer: + return AutoTokenizer.from_pretrained(self.model_name) + + @property + def eos_token_id(self) -> int: + return self.tokenizer.eos_token_id + + @cached_property + def model(self): + warnings.simplefilter("error") + model: TFMarianMTModel = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True) + assert isinstance(model, TFMarianMTModel) + c = model.config + self.assertListEqual(c.bad_words_ids, [[c.pad_token_id]]) + self.assertEqual(c.max_length, 512) + self.assertEqual(c.decoder_start_token_id, c.pad_token_id) + return model + + def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs): + generated_words = self.translate_src_text(**tokenizer_kwargs) + self.assertListEqual(self.expected_text, generated_words) + + def translate_src_text(self, **tokenizer_kwargs): + model_inputs = self.tokenizer.prepare_seq2seq_batch( + src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf" + ) + generated_ids = self.model.generate( + model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128 + ) + generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True) + return generated_words + + +@require_sentencepiece +@require_tokenizers +@is_pt_tf_cross_test +class TestMarian_MT_EN(AbstractMarianIntegrationTest): + """Cover low resource/high perplexity setting. 
This breaks if pad_token_id logits are not set to LARGE_NEGATIVE.""" + + src = "mt" + tgt = "en" + src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."] + expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."] + + @slow + def test_batch_generation_mt_en(self): + self._assert_generated_batch_equal_expected() + + +@is_pt_tf_cross_test +@require_sentencepiece +@require_tokenizers +class TestMarian_en_zh(AbstractMarianIntegrationTest): + src = "en" + tgt = "zh" + src_text = ["My name is Wolfgang and I live in Berlin"] + expected_text = ["我叫沃尔夫冈 我住在柏林"] + + @slow + def test_batch_generation_en_zh(self): + self._assert_generated_batch_equal_expected() + + +@is_pt_tf_cross_test +@require_sentencepiece +@require_tokenizers +class TestMarian_en_ROMANCE(AbstractMarianIntegrationTest): + """Multilingual on target side.""" + + src = "en" + tgt = "ROMANCE" + src_text = [ + ">>fr<< Don't spend so much time watching TV.", + ">>pt<< Your message has been sent.", + ">>es<< He's two years older than me.", + ] + expected_text = [ + "Ne passez pas autant de temps à regarder la télé.", + "A sua mensagem foi enviada.", + "Es dos años más viejo que yo.", + ] + + @slow + def test_batch_generation_en_ROMANCE_multi(self): + self._assert_generated_batch_equal_expected() + + @slow + def test_pipeline(self): + pipeline = TranslationPipeline(self.model, self.tokenizer, framework="tf") + output = pipeline(self.src_text) + self.assertEqual(self.expected_text, [x["translation_text"] for x in output]) diff --git a/tests/test_modeling_tf_mbart.py b/tests/test_modeling_tf_mbart.py new file mode 100644 index 000000000000..d631971c43b6 --- /dev/null +++ b/tests/test_modeling_tf_mbart.py @@ -0,0 +1,134 @@ +# coding=utf-8 +# Copyright 2020 HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
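+# ModelTester below flips normalize_before and add_final_layer_norm so that the common +# tests exercise the optional final layer_norm branch added to the shared TFBart +# encoder and decoder for mBART-style checkpoints.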
+import tempfile +import unittest + +from tests.test_configuration_common import ConfigTester +from tests.test_modeling_tf_bart import TFBartModelTester +from tests.test_modeling_tf_common import TFModelTesterMixin +from transformers import AutoTokenizer, MBartConfig, is_tf_available +from transformers.file_utils import cached_property +from transformers.testing_utils import is_pt_tf_cross_test, require_sentencepiece, require_tf, require_tokenizers, slow + + +if is_tf_available(): + + import tensorflow as tf + + from transformers import TFAutoModelForSeq2SeqLM, TFMBartForConditionalGeneration + + +class ModelTester(TFBartModelTester): + config_updates = dict(normalize_before=True, add_final_layer_norm=True) + config_cls = MBartConfig + + +@require_tf +class TestTFMBartCommon(TFModelTesterMixin, unittest.TestCase): + all_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else () + all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else () + model_tester_cls = ModelTester + is_encoder_decoder = True + test_pruning = False + + def setUp(self): + self.model_tester = self.model_tester_cls(self) + self.config_tester = ConfigTester(self, config_class=MBartConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_inputs_embeds(self): + # inputs_embeds not supported + pass + + def test_saved_model_with_hidden_states_output(self): + # Should be uncommented during patrick TF refactor + pass + + def test_saved_model_with_attentions_output(self): + # Should be uncommented during patrick TF refactor + pass + + def test_compile_tf_model(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0) + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy") + + model_class = self.all_generative_model_classes[0] + input_ids = { + "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"), + "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"), + } + + # Prepare our model + model = model_class(config) + model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving. 
+ # Let's load it from the disk to be sure we can use pretrained weights + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname) + + outputs_dict = model(input_ids) + hidden_states = outputs_dict[0] + + # Add a dense layer on top to test integration with other keras modules + outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states) + + # Compile extended model + extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs]) + extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) + + +@is_pt_tf_cross_test +@require_sentencepiece +@require_tokenizers +class TestMBartEnRO(unittest.TestCase): + src_text = [ + " UN Chief Says There Is No Military Solution in Syria", + ] + expected_text = [ + "Şeful ONU declară că nu există o soluţie militară în Siria", + ] + model_name = "facebook/mbart-large-en-ro" + + @cached_property + def tokenizer(self): + return AutoTokenizer.from_pretrained(self.model_name) + + @cached_property + def model(self): + model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True) + return model + + def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs): + generated_words = self.translate_src_text(**tokenizer_kwargs) + self.assertListEqual(self.expected_text, generated_words) + + def translate_src_text(self, **tokenizer_kwargs): + model_inputs = self.tokenizer.prepare_seq2seq_batch( + src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf" + ) + generated_ids = self.model.generate( + model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2 + ) + generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + return generated_words + + @slow + def test_batch_generation_en_ro(self): + self._assert_generated_batch_equal_expected() diff --git a/tests/test_modeling_tf_pegasus.py b/tests/test_modeling_tf_pegasus.py new file mode 100644 index 000000000000..32d98bfd7bf6 --- /dev/null +++ b/tests/test_modeling_tf_pegasus.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2020 HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
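+# The integration tests below load the PyTorch checkpoint with from_pt=True rather +# than assuming a native TF checkpoint exists, which is why they run under +# @is_pt_tf_cross_test instead of plain @require_tf.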
+import tempfile +import unittest + +from transformers import AutoTokenizer, PegasusConfig, is_tf_available +from transformers.file_utils import cached_property +from transformers.testing_utils import is_pt_tf_cross_test, require_sentencepiece, require_tf, require_tokenizers, slow + +from .test_configuration_common import ConfigTester +from .test_modeling_pegasus import PGE_ARTICLE, XSUM_ENTRY_LONGER +from .test_modeling_tf_bart import TFBartModelTester +from .test_modeling_tf_common import TFModelTesterMixin + + +if is_tf_available(): + import tensorflow as tf + + from transformers import TFAutoModelForSeq2SeqLM, TFPegasusForConditionalGeneration + + +class ModelTester(TFBartModelTester): + config_updates = dict( + normalize_before=True, + static_position_embeddings=True, + ) + hidden_act = "relu" + config_cls = PegasusConfig + + +@require_tf +class TestTFPegasusCommon(TFModelTesterMixin, unittest.TestCase): + all_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else () + all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else () + model_tester_cls = ModelTester + is_encoder_decoder = True + test_pruning = False + + def setUp(self): + self.model_tester = self.model_tester_cls(self) + self.config_tester = ConfigTester(self, config_class=PegasusConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_inputs_embeds(self): + # inputs_embeds not supported + pass + + def test_saved_model_with_hidden_states_output(self): + # Should be uncommented during patrick TF refactor + pass + + def test_saved_model_with_attentions_output(self): + # Should be uncommented during patrick TF refactor + pass + + def test_compile_tf_model(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0) + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy") + + model_class = self.all_generative_model_classes[0] + input_ids = { + "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"), + "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"), + } + + # Prepare our model + model = model_class(config) + model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving. 
+ # Let's load it from the disk to be sure we can use pretrained weights + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname) + + outputs_dict = model(input_ids) + hidden_states = outputs_dict[0] + + # Add a dense layer on top to test integration with other keras modules + outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states) + + # Compile extended model + extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs]) + extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) + + +@is_pt_tf_cross_test +@require_sentencepiece +@require_tokenizers +class TFPegasusIntegrationTests(unittest.TestCase): + src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER] + expected_text = [ + "California's largest electricity provider has cut power to hundreds of thousands of customers in an effort to reduce the risk of wildfires.", + 'N-Dubz have revealed they\'re "grateful" to have been nominated for four Mobo Awards.', + ] # differs slightly from pytorch, likely due to numerical differences in linear layers + model_name = "google/pegasus-xsum" + + @cached_property + def tokenizer(self): + return AutoTokenizer.from_pretrained(self.model_name) + + @cached_property + def model(self): + model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True) + return model + + def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs): + generated_words = self.translate_src_text(**tokenizer_kwargs) + assert self.expected_text == generated_words + + def translate_src_text(self, **tokenizer_kwargs): + model_inputs = self.tokenizer.prepare_seq2seq_batch( + src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf" + ) + generated_ids = self.model.generate( + model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + num_beams=2, + use_cache=True, + ) + generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True) + return generated_words + + @slow + def test_batch_generation(self): + self._assert_generated_batch_equal_expected() diff --git a/utils/check_repo.py b/utils/check_repo.py index 69dfd9ef0d40..99132ab02e6a 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -67,6 +67,7 @@ "xlm_prophetnet": "xlmprophetnet.rst", "xlm_roberta": "xlmroberta.rst", "bert_generation": "bertgeneration.rst", + "marian": "marian.rst", } # This is to make sure the transformers module imported is the one in the repo. @@ -148,7 +149,6 @@ def get_model_doc_files(): _ignore_modules = [ "auto", "dialogpt", - "marian", "retribert", ] doc_files = [] @@ -245,6 +245,7 @@ def check_models_are_documented(module, doc_file): def _get_model_name(module): """ Get the model name for the module defining it.""" splits = module.__name__.split("_") + # Special case for transfo_xl if splits[-1] == "xl": return "_".join(splits[-2:])