diff --git a/docs/source/index.rst b/docs/source/index.rst index 6e2cc51c229c..2d1385af0f0a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -126,7 +126,9 @@ conversion utilities for the following models: Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 23. `Pegasus `_ (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `_ by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -24. `Other community models `_, contributed by the `community +24. `MBart `_ (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `_ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov + Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. +25. `Other community models `_, contributed by the `community `_. .. toctree:: @@ -208,6 +210,7 @@ conversion utilities for the following models: model_doc/mobilebert model_doc/dpr model_doc/pegasus + model_doc/mbart internal/modeling_utils internal/tokenization_utils internal/pipelines_utils diff --git a/docs/source/main_classes/model.rst b/docs/source/main_classes/model.rst index d89e788f191b..f6500438b110 100644 --- a/docs/source/main_classes/model.rst +++ b/docs/source/main_classes/model.rst @@ -12,7 +12,9 @@ are common among all the models to: - prune the attention heads of the model. The other methods that are common to each model are defined in :class:`~transformers.modeling_utils.ModuleUtilsMixin` -(for the PyTorch models) and :class:`~transformers.modeling_tf_utils.TFModuleUtilsMixin` (for the TensorFlow models). +(for the PyTorch models) and :class:`~transformers.modeling_tf_utils.TFModuleUtilsMixin` (for the TensorFlow models) or +for text generation, :class:`~transformers.generation_utils.GenerationMixin` (for the PyTorch models) and +:class:`~transformers.generation_tf_utils.TFGenerationMixin` (for the TensorFlow models) ``PreTrainedModel`` @@ -46,4 +48,8 @@ The other methods that are common to each model are defined in :class:`~transfor Generative models ~~~~~~~~~~~~~~~~~ -Coming soon +.. autoclass:: transformers.generation_utils.GenerationMixin + :members: + +.. autoclass:: transformers.generation_tf_utils.TFGenerationMixin + :members: \ No newline at end of file diff --git a/docs/source/model_doc/bart.rst b/docs/source/model_doc/bart.rst index 81d138232dd4..42318003cd25 100644 --- a/docs/source/model_doc/bart.rst +++ b/docs/source/model_doc/bart.rst @@ -49,13 +49,6 @@ BartTokenizer :members: -MBartTokenizer -~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: transformers.MBartTokenizer - :members: build_inputs_with_special_tokens, prepare_seq2seq_batch - - BartModel ~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/mbart.rst b/docs/source/model_doc/mbart.rst new file mode 100644 index 000000000000..7305fce94126 --- /dev/null +++ b/docs/source/model_doc/mbart.rst @@ -0,0 +1,37 @@ +MBart +---------------------------------------------------- +**DISCLAIMER:** If you see something strange, +file a `Github Issue `__ and assign +@sshleifer + +Overview +~~~~~~~~~~~~~~~~~~~~~ +The MBart model was presented in `Multilingual Denoising Pre-training for Neural Machine Translation `_ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov +Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. According to the abstract, + +MBART is a sequence-to-sequence denoising auto-encoder pre-trained on large-scale monolingual corpora in many languages using the BART objective. 
mBART is one of the first methods for pre-training a complete sequence-to-sequence model by denoising full texts in multiple languages, while previous approaches have focused only on the encoder, decoder, or reconstructing parts of the text. + +The Authors' code can be found `here `__ + + +MBartConfig +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MBartConfig + :members: + + +MBartTokenizer +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MBartTokenizer + :members: build_inputs_with_special_tokens, prepare_seq2seq_batch + + +MBartForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MBartForConditionalGeneration + :members: generate, forward + + diff --git a/docs/source/training.rst b/docs/source/training.rst index 7fe1498472bc..5d0cbe982bbb 100644 --- a/docs/source/training.rst +++ b/docs/source/training.rst @@ -282,7 +282,7 @@ your own ``compute_metrics`` function and pass it to the trainer. .. code-block:: python - from sklearn.metrics import precision_recall_fscore_support + from sklearn.metrics import accuracy_score, precision_recall_fscore_support def compute_metrics(pred): labels = pred.label_ids diff --git a/examples/bert-loses-patience/test_run_glue_with_pabee.py b/examples/bert-loses-patience/test_run_glue_with_pabee.py index bce18544940c..e626d220c683 100644 --- a/examples/bert-loses-patience/test_run_glue_with_pabee.py +++ b/examples/bert-loses-patience/test_run_glue_with_pabee.py @@ -20,7 +20,7 @@ def get_setup_file(): return args.f -def clean_test_dir(path="./tests/fixtures/tests_samples/temp_dir"): +def clean_test_dir(path): shutil.rmtree(path, ignore_errors=True) @@ -37,7 +37,6 @@ def test_run_glue(self): --task_name mrpc --do_train --do_eval - --output_dir ./tests/fixtures/tests_samples/temp_dir --per_gpu_train_batch_size=2 --per_gpu_eval_batch_size=1 --learning_rate=2e-5 @@ -46,10 +45,13 @@ def test_run_glue(self): --overwrite_output_dir --seed=42 --max_seq_length=128 - """.split() + """ + output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs)) + testargs += "--output_dir " + output_dir + testargs = testargs.split() with patch.object(sys, "argv", testargs): result = run_glue_with_pabee.main() for value in result.values(): self.assertGreaterEqual(value, 0.75) - clean_test_dir() + clean_test_dir(output_dir) diff --git a/examples/test_examples.py b/examples/test_examples.py index aa7af4bf2c7d..ea4117d6fd54 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -52,7 +52,7 @@ def get_setup_file(): return args.f -def clean_test_dir(path="./tests/fixtures/tests_samples/temp_dir"): +def clean_test_dir(path): shutil.rmtree(path, ignore_errors=True) @@ -68,7 +68,6 @@ def test_run_glue(self): --task_name mrpc --do_train --do_eval - --output_dir ./tests/fixtures/tests_samples/temp_dir --per_device_train_batch_size=2 --per_device_eval_batch_size=1 --learning_rate=1e-4 @@ -77,13 +76,16 @@ def test_run_glue(self): --overwrite_output_dir --seed=42 --max_seq_length=128 - """.split() + """ + output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs)) + testargs += "--output_dir " + output_dir + testargs = testargs.split() with patch.object(sys, "argv", testargs): result = run_glue.main() del result["eval_loss"] for value in result.values(): self.assertGreaterEqual(value, 0.75) - clean_test_dir() + clean_test_dir(output_dir) def test_run_pl_glue(self): stream_handler = logging.StreamHandler(sys.stdout) @@ -96,13 +98,15 @@ def test_run_pl_glue(self): --task mrpc --do_train 
--do_predict - --output_dir ./tests/fixtures/tests_samples/temp_dir --train_batch_size=32 --learning_rate=1e-4 --num_train_epochs=1 --seed=42 --max_seq_length=128 - """.split() + """ + output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs)) + testargs += "--output_dir " + output_dir + testargs = testargs.split() if torch.cuda.is_available(): testargs += ["--fp16", "--gpus=1"] @@ -119,7 +123,7 @@ def test_run_pl_glue(self): # for k, v in result.items(): # self.assertGreaterEqual(v, 0.75, f"({k})") # - clean_test_dir() + clean_test_dir(output_dir) def test_run_language_modeling(self): stream_handler = logging.StreamHandler(sys.stdout) @@ -133,17 +137,19 @@ def test_run_language_modeling(self): --line_by_line --train_data_file ./tests/fixtures/sample_text.txt --eval_data_file ./tests/fixtures/sample_text.txt - --output_dir ./tests/fixtures/tests_samples/temp_dir --overwrite_output_dir --do_train --do_eval --num_train_epochs=1 --no_cuda - """.split() + """ + output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs)) + testargs += "--output_dir " + output_dir + testargs = testargs.split() with patch.object(sys, "argv", testargs): result = run_language_modeling.main() self.assertLess(result["perplexity"], 35) - clean_test_dir() + clean_test_dir(output_dir) def test_run_squad(self): stream_handler = logging.StreamHandler(sys.stdout) @@ -154,7 +160,6 @@ def test_run_squad(self): --model_type=distilbert --model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad --data_dir=./tests/fixtures/tests_samples/SQUAD - --output_dir=./tests/fixtures/tests_samples/temp_dir --max_steps=10 --warmup_steps=2 --do_train @@ -165,12 +170,15 @@ def test_run_squad(self): --per_gpu_eval_batch_size=1 --overwrite_output_dir --seed=42 - """.split() + """ + output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs)) + testargs += "--output_dir " + output_dir + testargs = testargs.split() with patch.object(sys, "argv", testargs): result = run_squad.main() self.assertGreaterEqual(result["f1"], 25) self.assertGreaterEqual(result["exact"], 21) - clean_test_dir() + clean_test_dir(output_dir) def test_generation(self): stream_handler = logging.StreamHandler(sys.stdout) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b9ddf890ebe2..781737eaf05f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -22,7 +22,7 @@ # Configurations from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig -from .configuration_bart import BartConfig, MBartConfig +from .configuration_bart import BartConfig from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig @@ -34,6 +34,7 @@ from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig from .configuration_marian import MarianConfig +from .configuration_mbart import MBartConfig from .configuration_mmbt import MMBTConfig from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig @@ -131,7 +132,7 @@ # Tokenizers 
from .tokenization_albert import AlbertTokenizer from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer -from .tokenization_bart import BartTokenizer, BartTokenizerFast, MBartTokenizer +from .tokenization_bart import BartTokenizer, BartTokenizerFast from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer from .tokenization_camembert import CamembertTokenizer @@ -149,6 +150,7 @@ from .tokenization_flaubert import FlaubertTokenizer from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast +from .tokenization_mbart import MBartTokenizer from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast from .tokenization_pegasus import PegasusTokenizer @@ -298,6 +300,7 @@ BartForQuestionAnswering, BART_PRETRAINED_MODEL_ARCHIVE_LIST, ) + from .modeling_mbart import MBartForConditionalGeneration from .modeling_marian import MarianMTModel from .tokenization_marian import MarianTokenizer from .modeling_roberta import ( diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index 1b1d32a019db..62090e931ad7 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -19,7 +19,7 @@ from collections import OrderedDict from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig -from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, MBartConfig +from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig @@ -30,6 +30,7 @@ from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig from .configuration_marian import MarianConfig +from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig from .configuration_mobilebert import MobileBertConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig from .configuration_pegasus import PegasusConfig @@ -52,6 +53,7 @@ for pretrained_map in [ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BART_PRETRAINED_CONFIG_ARCHIVE_MAP, + MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py index 42d69ba44644..8edcb4e69cee 100644 --- a/src/transformers/configuration_bart.py +++ b/src/transformers/configuration_bart.py @@ -32,6 +32,7 @@ "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json", "yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json", } + BART_CONFIG_ARGS_DOC = r""" Args: vocab_size (:obj:`int`, optional, defaults to 50265): @@ -209,8 +210,3 @@ def is_valid_mbart(self) -> bool: if self.normalize_before or self.add_final_layer_norm or self.scale_embedding: logger.info("This 
configuration is a mixture of MBART and BART settings") return False - - -class MBartConfig(BartConfig): - model_type = "mbart" - """See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json.""" diff --git a/src/transformers/configuration_mbart.py b/src/transformers/configuration_mbart.py new file mode 100644 index 000000000000..a01fef69143a --- /dev/null +++ b/src/transformers/configuration_mbart.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MBART configuration """ + +import logging + +from .configuration_bart import BartConfig + + +logger = logging.getLogger(__name__) + +MBART_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json", + "facebook/mbart-large-cc25": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-cc25/config.json", +} + + +class MBartConfig(BartConfig): + model_type = "mbart" + """See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json.""" diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index f71e2c2c0540..2ccc0492acb0 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -91,7 +91,7 @@ class PretrainedConfig(object): keep for top-k-filtering that will be used by default in the :obj:`generate` method of the model. - **top_p** (:obj:`float`, `optional`, defaults to 1) -- Value that will be used by default in the :obj:`generate` method of the model for ``top_p``. If set to float < 1, only the most probable tokens - with probabilities that add up to ``top_p`` or highest are kept for generation. + with probabilities that add up to ``top_p`` or higher are kept for generation. - **repetition_penalty** (:obj:`float`, `optional`, defaults to 1) -- Parameter for repetition penalty that will be used by default in the :obj:`generate` method of the model. 1.0 means no penalty. - **length_penalty** (:obj:`float`, `optional`, defaults to 1) -- Exponential penalty to the length that diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 41ef2f51e008..4e7672594744 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -25,10 +25,15 @@ class TFGenerationMixin: """ - A class contraining all of the functions supporting generation, to be used as a mixin in TFPreTrainedModel. + A class containing all of the functions supporting generation, to be used as a mixin in + :class:`~transformers.TFPreTrainedModel`. """ def prepare_inputs_for_generation(self, inputs, **kwargs): + """ + Implement in subclasses of :class:`~transformers.TFPreTrainedModel` for custom behavior to prepare inputs in the + generate method.
+ """ return {"inputs": inputs} def _use_cache(self, outputs, use_cache): @@ -62,87 +67,83 @@ def generate( decoder_start_token_id=None, use_cache=None, ): - r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling - and beam-search. + r""" + Generates sequences for models with a language modeling head. The method currently supports greedy decoding, + beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. - Adapted in part from `Facebook's XLM beam search code`_. + Adapted in part from `Facebook's XLM beam search code + `__. - .. _`Facebook's XLM beam search code`: - https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529 + Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the + attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values + indicated are the default values of those config. + Most of these parameters are explained in more detail in `this blog post + `__. Parameters: - input_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)` - The sequence used as a prompt for the generation. If `None` the method initializes - it as an empty `tf.Tensor` of shape `(1,)`. - - max_length: (`optional`) int - The max length of the sequence to be generated. Between 1 and infinity. Default to 20. - - min_length: (`optional`) int - The min length of the sequence to be generated. Between 0 and infinity. Default to 0. - do_sample: (`optional`) bool - If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`. - - early_stopping: (`optional`) bool - if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`. - - num_beams: (`optional`) int - Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1. - - temperature: (`optional`) float - The value used to module the next token probabilities. Must be strictely positive. Default to 1.0. - - top_k: (`optional`) int - The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. - - top_p: (`optional`) float - The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. - - repetition_penalty: (`optional`) float - The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0. - - bos_token_id: (`optional`) int - Beginning of sentence token if no prompt is provided. Default to specicic model bos_token_id or None if it does not exist. - - pad_token_id: (`optional`) int - Pad token. Defaults to pad_token_id as defined in the models config. - - eos_token_id: (`optional`) int - EOS token. Defaults to eos_token_id as defined in the models config. - - length_penalty: (`optional`) float - Exponential penalty to the length. Default to 1. - - no_repeat_ngram_size: (`optional`) int - If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once. - - bad_words_ids: (`optional`) list of lists of int - `bad_words_ids` contains tokens that are not allowed to be generated. 
In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`. - - num_return_sequences: (`optional`) int - The number of independently computed returned sequences for each element in the batch. Default to 1. - - attention_mask (`optional`) obj: `tf.Tensor` with `dtype=tf.int32` of same shape as `input_ids` - Mask to avoid performing attention on padding token indices. - Mask values selected in ``[0, 1]``: - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. - Defaults to `None`. + input_ids (:obj:`tf.Tensor` of :obj:`dtype=tf.int32` and shape :obj:`(batch_size, sequence_length)`, `optional`): + The sequence used as a prompt for the generation. If :obj:`None` the method initializes + it as an empty :obj:`tf.Tensor` of shape :obj:`(1,)`. + max_length (:obj:`int`, `optional`, defaults to 20): + The maximum length of the sequence to be generated. + min_length (:obj:`int`, `optional`, defaults to 0): + The minimum length of the sequence to be generated. + do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to use sampling; use greedy decoding otherwise. + early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. + num_beams (:obj:`int`, `optional`, defaults to 1): + Number of beams for beam search. 1 means no beam search. + temperature (:obj:`float`, `optional`, defaults to 1.0): + The value used to modulate the next token probabilities. + top_k (:obj:`int`, `optional`, defaults to 50): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (:obj:`float`, `optional`, defaults to 1.0): + If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or + higher are kept for generation. + repetition_penalty (:obj:`float`, `optional`, defaults to 1.0): + The parameter for repetition penalty. 1.0 means no penalty. See `this paper + `__ for more details. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + bos_token_id (:obj:`int`, `optional`): + The id of the `beginning-of-sequence` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + length_penalty (:obj:`float`, `optional`, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. + + Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in + order to encourage the model to produce longer sequences. + no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): + If set to int > 0, all ngrams of that size can only occur once. + bad_words_ids (:obj:`List[int]`, `optional`): + List of token ids that are not allowed to be generated. In order to get the tokens of the words that + should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. + num_return_sequences (:obj:`int`, `optional`, defaults to 1): + The number of independently computed returned sequences for each element in the batch. + attention_mask (:obj:`tf.Tensor` of :obj:`dtype=tf.int32` and shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for + tokens that are not masked, and 0 for masked tokens. + + If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token.
`What are attention masks? <../glossary.html#attention-mask>`__ - - decoder_start_token_id=None: (`optional`) int - If an encoder-decoder model starts decoding with a different token than BOS. - Defaults to `None` and is changed to `BOS` later. - - use_cache: (`optional`) bool - If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`. + decoder_start_token_id (:obj:`int`, `optional`): + If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token. + use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should use the past last key/values attentions (if applicable to the model) to + speed up decoding. + model_specific_kwargs: + Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. Return: - output: `tf.Tensor` of `dtype=tf.int32` shape `(batch_size * num_return_sequences, sequence_length)` - sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id` + :obj:`tf.Tensor` of :obj:`dtype=tf.int32` and shape :obj:`(batch_size * num_return_sequences, sequence_length)`: + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. Examples:: diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 6a0d7ad0200d..6b0e09d7d1d4 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -27,13 +27,22 @@ class GenerationMixin: """ - A class contraining all of the functions supporting generation, to be used as a mixin in PreTrainedModel. + A class containing all of the functions supporting generation, to be used as a mixin in + :class:`~transformers.PreTrainedModel`. """ def prepare_inputs_for_generation(self, input_ids, **kwargs): + """ + Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the + generate method. + """ return {"input_ids": input_ids} def adjust_logits_during_generation(self, logits, **kwargs): + """ + Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in + the generate method. + """ return logits def _use_cache(self, outputs, use_cache): @@ -45,7 +54,9 @@ def _use_cache(self, outputs, use_cache): return True def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty): - """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """ + """ + Enforce the repetition penalty (from the `CTRL paper `__). + """ for i in range(batch_size * num_beams): for previous_token in set(prev_output_tokens[i].tolist()): # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability @@ -123,89 +134,83 @@ def generate( use_cache: Optional[bool] = None, **model_specific_kwargs ) -> torch.LongTensor: - r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. + r""" + Generates sequences for models with a language modeling head. The method currently supports greedy decoding, + beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. - Adapted in part from `Facebook's XLM beam search code`_.
+ Adapted in part from `Facebook's XLM beam search code + `__. - .. _`Facebook's XLM beam search code`: - https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529 + Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the + attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values + indicated are the default values of those config. + Most of these parameters are explained in more detail in `this blog post + `__. Parameters: - input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)` - The sequence used as a prompt for the generation. If `None` the method initializes - it as an empty `torch.LongTensor` of shape `(1,)`. - - max_length: (`optional`) int - The max length of the sequence to be generated. Between `min_length` and infinity. Default to 20. - - min_length: (`optional`) int - The min length of the sequence to be generated. Between 0 and infinity. Default to 0. - - do_sample: (`optional`) bool - If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`. - - early_stopping: (`optional`) bool - if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`. - - num_beams: (`optional`) int - Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1. - - temperature: (`optional`) float - The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. - - top_k: (`optional`) int - The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. - - top_p: (`optional`) float - The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. - - repetition_penalty: (`optional`) float - The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0. - - pad_token_id: (`optional`) int - Padding token. Default to specicic model pad_token_id or None if it does not exist. - - bos_token_id: (`optional`) int - BOS token. Defaults to `bos_token_id` as defined in the models config. - - eos_token_id: (`optional`) int - EOS token. Defaults to `eos_token_id` as defined in the models config. - - length_penalty: (`optional`) float - Exponential penalty to the length. Default to 1. - - no_repeat_ngram_size: (`optional`) int - If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once. - bad_words_ids: (`optional`) list of lists of int - `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`. - - num_return_sequences: (`optional`) int - The number of independently computed returned sequences for each element in the batch. Default to 1. - - attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids` - Mask to avoid performing attention on padding token indices. - Mask values selected in ``[0, 1]``: - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. - Defaults to `None`. 
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + The sequence used as a prompt for the generation. If :obj:`None` the method initializes + it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`. + max_length (:obj:`int`, `optional`, defaults to 20): + The maximum length of the sequence to be generated. + min_length (:obj:`int`, `optional`, defaults to 0): + The minimum length of the sequence to be generated. + do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to use sampling; use greedy decoding otherwise. + early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. + num_beams (:obj:`int`, `optional`, defaults to 1): + Number of beams for beam search. 1 means no beam search. + temperature (:obj:`float`, `optional`, defaults to 1.0): + The value used to modulate the next token probabilities. + top_k (:obj:`int`, `optional`, defaults to 50): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (:obj:`float`, `optional`, defaults to 1.0): + If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or + higher are kept for generation. + repetition_penalty (:obj:`float`, `optional`, defaults to 1.0): + The parameter for repetition penalty. 1.0 means no penalty. See `this paper + `__ for more details. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + bos_token_id (:obj:`int`, `optional`): + The id of the `beginning-of-sequence` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + length_penalty (:obj:`float`, `optional`, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. + + Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in + order to encourage the model to produce longer sequences. + no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): + If set to int > 0, all ngrams of that size can only occur once. + bad_words_ids (:obj:`List[int]`, `optional`): + List of token ids that are not allowed to be generated. In order to get the tokens of the words that + should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. + num_return_sequences (:obj:`int`, `optional`, defaults to 1): + The number of independently computed returned sequences for each element in the batch. + attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for + tokens that are not masked, and 0 for masked tokens. + + If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token. `What are attention masks? <../glossary.html#attention-mask>`__ - - decoder_start_token_id=None: (`optional`) int - If an encoder-decoder model starts decoding with a different token than BOS. - Defaults to `None` and is changed to `BOS` later. - - use_cache: (`optional`) bool - If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`. - - model_specific_kwargs: (`optional`) dict - Additional model specific kwargs will be forwarded to the `forward` function of the model.
+ decoder_start_token_id (:obj:`int`, `optional`): + If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token. + use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should use the past last key/values attentions (if applicable to the model) to + speed up decoding. + model_specific_kwargs: + Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. Return: - output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)` - sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id` + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. Examples:: @@ -372,11 +377,16 @@ def generate( if self.config.is_encoder_decoder: if decoder_start_token_id is None: - decoder_start_token_id = bos_token_id + # see if BOS token can be used for decoder_start_token_id + if bos_token_id is not None: + decoder_start_token_id = bos_token_id + elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"): + decoder_start_token_id = self.config.decoder.bos_token_id + else: + raise ValueError( + "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" + ) - assert ( - decoder_start_token_id is not None - ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 7a45267703aa..02088f565c8f 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -32,6 +32,7 @@ FlaubertConfig, GPT2Config, LongformerConfig, + MBartConfig, MobileBertConfig, OpenAIGPTConfig, PegasusConfig, @@ -116,6 +117,7 @@ LongformerModel, ) from .modeling_marian import MarianMTModel +from .modeling_mbart import MBartForConditionalGeneration from .modeling_mobilebert import ( MobileBertForMaskedLM, MobileBertForMultipleChoice, @@ -289,6 +291,7 @@ (T5Config, T5ForConditionalGeneration), (PegasusConfig, PegasusForConditionalGeneration), (MarianConfig, MarianMTModel), + (MBartConfig, MBartForConditionalGeneration), (BartConfig, BartForConditionalGeneration), (EncoderDecoderConfig, EncoderDecoderModel), ] diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index 0ee5f2962b6d..664b4181f5b7 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -287,6 +287,8 @@ def forward( **kwargs_decoder, ) + # TODO(PVP): currently it is not possible to use `past` + # with the encoder/decoder framework -> should be implemented return decoder_outputs + encoder_outputs def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs): @@ -299,15 +301,24 @@ def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwarg encoder_outputs = (past,) decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids) - - return { + decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None 
+ input_dict = { "attention_mask": attention_mask, - "decoder_attention_mask": decoder_inputs["attention_mask"], + "decoder_attention_mask": decoder_attention_mask, "decoder_input_ids": decoder_inputs["input_ids"], "encoder_outputs": encoder_outputs, } + # Ideally all models should have a `use_cache` + # leave following to ifs until all have it implemented + if "use_cache" in decoder_inputs: + input_dict["decoder_use_cache"] = decoder_inputs["use_cache"] + + if "past_key_values" in decoder_inputs: + input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"] + + return input_dict + def _reorder_cache(self, past, beam_idx): - # as a default encoder-decoder models do not re-order the past. - # TODO(PVP): might have to be updated, e.g. if GPT2 is to be used as a decoder - return past + # apply decoder cache reordering here + return self.decoder._reorder_cache(past, beam_idx) diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index ea23a819d547..7f44f8a5ac3e 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -118,7 +118,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): class Attention(nn.Module): - def __init__(self, nx, n_ctx, config, scale=False): + def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False): super().__init__() n_state = nx # in Attention: n_state=768 (nx=n_embd) @@ -131,8 +131,12 @@ def __init__(self, nx, n_ctx, config, scale=False): self.n_head = config.n_head self.split_size = n_state self.scale = scale - - self.c_attn = Conv1D(n_state * 3, nx) + self.is_cross_attention = is_cross_attention + if self.is_cross_attention: + self.c_attn = Conv1D(2 * n_state, nx) + self.q_attn = Conv1D(n_state, nx) + else: + self.c_attn = Conv1D(3 * n_state, nx) self.c_proj = Conv1D(n_state, nx) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) @@ -160,8 +164,11 @@ def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions= if self.scale: w = w / (float(v.size(-1)) ** 0.5) nd, ns = w.size(-2), w.size(-1) - mask = self.bias[:, :, ns - nd : ns, :ns] - w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype)) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + mask = self.bias[:, :, ns - nd : ns, :ns] + w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype)) if attention_mask is not None: # Apply the attention mask @@ -193,10 +200,26 @@ def split_heads(self, x, k=False): return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def forward( - self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache=False, + output_attentions=False, ): - x = self.c_attn(x) - query, key, value = x.split(self.split_size, dim=2) + if encoder_hidden_states is not None: + assert hasattr( + self, "q_attn" + ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`." 
+ query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + query = self.split_heads(query) key = self.split_heads(key, k=True) value = self.split_heads(value) @@ -239,32 +262,64 @@ def forward(self, x): class Block(nn.Module): def __init__(self, n_ctx, config, scale=False): super().__init__() - nx = config.n_embd - inner_dim = config.n_inner if config.n_inner is not None else 4 * nx - self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) - self.attn = Attention(nx, n_ctx, config, scale) - self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) + hidden_size = config.n_embd + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = Attention(hidden_size, n_ctx, config, scale) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + if config.add_cross_attention: + self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = MLP(inner_dim, config) def forward( - self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False, + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache=False, + output_attentions=False, ): - output_attn = self.attn( - self.ln_1(x), + attn_outputs = self.attn( + self.ln_1(hidden_states), layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, ) - a = output_attn[0] # output_attn: a, present, (attentions) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + hidden_states + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + assert hasattr( + self, "crossattention" + ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" + cross_attn_outputs = self.crossattention( + self.ln_cross_attn(hidden_states), + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = hidden_states + attn_output + outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights - x = x + a - m = self.mlp(self.ln_2(x)) - x = x + m + feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states)) + # residual connection + hidden_states = hidden_states + feed_forward_hidden_states - outputs = [x] + output_attn[1:] - return outputs # x, present, (attentions) + outputs = [hidden_states] + outputs + return outputs # hidden_states, present, (cross_attentions, attentions) class GPT2PreTrainedModel(PreTrainedModel): @@ -449,6 +504,8 @@ def forward( position_ids=None, head_mask=None, inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, use_cache=None, output_attentions=None, output_hidden_states=None, @@ -506,7 +563,7 @@ def forward( # So we can 
broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask[:, None, None, :] # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for @@ -516,6 +573,17 @@ def forward( attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * -10000.0 + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -546,6 +614,8 @@ def forward( layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, ) @@ -593,17 +663,21 @@ def __init__(self, config): def get_output_embeddings(self): return self.lm_head - def prepare_inputs_for_generation(self, input_ids, past, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): # only last token for inputs_ids if past is defined in kwargs if past: input_ids = input_ids[:, -1].unsqueeze(-1) - return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]} + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + } @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, - checkpoint="ctrl", + checkpoint="gpt2", output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC, ) @@ -616,6 +690,8 @@ def forward( position_ids=None, head_mask=None, inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, labels=None, use_cache=None, output_attentions=None, @@ -648,6 +724,8 @@ def forward( position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, diff --git a/src/transformers/modeling_mbart.py b/src/transformers/modeling_mbart.py new file mode 100644 index 000000000000..60e47fe31520 --- /dev/null +++ b/src/transformers/modeling_mbart.py @@ -0,0 +1,38 @@ +from .configuration_mbart import MBartConfig +from .file_utils import add_start_docstrings +from .modeling_bart import BartForConditionalGeneration + + +_CONFIG_FOR_DOC = "MBartConfig" +_TOKENIZER_FOR_DOC = "MBartTokenizer" + +MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/mbart-large-cc25", + "facebook/mbart-large-en-ro", + # See all multilingual BART models
at https://huggingface.co/models?filter=mbart +] + +MBART_START_DOCSTRING = r""" + + This model is a PyTorch `torch.nn.Module `__ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~transformers.MBartConfig`): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + + +@add_start_docstrings( + "The BART Model with a language modeling head. Can be used for machine translation.", MBART_START_DOCSTRING +) +class MBartForConditionalGeneration(BartForConditionalGeneration): + """ + This class overrides :class:`~transformers.BartForConditionalGeneration`. Please check the + superclass for the appropriate documentation alongside usage examples. + """ + + config_class = MBartConfig diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index cdede96cf5bc..8fee16f1bc93 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -46,7 +46,7 @@ ) from .configuration_utils import PretrainedConfig from .tokenization_albert import AlbertTokenizer -from .tokenization_bart import BartTokenizer, BartTokenizerFast, MBartTokenizer +from .tokenization_bart import BartTokenizer, BartTokenizerFast from .tokenization_bert import BertTokenizer, BertTokenizerFast from .tokenization_bert_japanese import BertJapaneseTokenizer from .tokenization_camembert import CamembertTokenizer @@ -57,6 +57,7 @@ from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast from .tokenization_marian import MarianTokenizer +from .tokenization_mbart import MBartTokenizer from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast from .tokenization_pegasus import PegasusTokenizer diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py index 2348ce86d660..e72d4963b27c 100644 --- a/src/transformers/tokenization_bart.py +++ b/src/transformers/tokenization_bart.py @@ -14,13 +14,8 @@ # limitations under the License. 
import logging -from typing import List, Optional -from .file_utils import add_start_docstrings_to_callable from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast -from .tokenization_utils import BatchEncoding -from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING -from .tokenization_xlm_roberta import XLMRobertaTokenizer logger = logging.getLogger(__name__) @@ -55,258 +50,3 @@ class BartTokenizerFast(RobertaTokenizerFast): "vocab_file": {m: vocab_url for m in _all_bart_models}, "merges_file": {m: merges_url for m in _all_bart_models}, } - - -_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"] -SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model" - -FAIRSEQ_LANGUAGE_CODES = [ - "ar_AR", - "cs_CZ", - "de_DE", - "en_XX", - "es_XX", - "et_EE", - "fi_FI", - "fr_XX", - "gu_IN", - "hi_IN", - "it_IT", - "ja_XX", - "kk_KZ", - "ko_KR", - "lt_LT", - "lv_LV", - "my_MM", - "ne_NP", - "nl_XX", - "ro_RO", - "ru_RU", - "si_LK", - "tr_TR", - "vi_VN", - "zh_CN", -] - - -class MBartTokenizer(XLMRobertaTokenizer): - """ - This inherits from XLMRobertaTokenizer. ``prepare_seq2seq_batch`` should be used to encode inputs. - Other tokenizer methods like ``encode`` do not work properly. - The tokenization method is `` `` for source language documents, and - `` ``` for target language documents. - - Examples:: - - >>> from transformers import MBartTokenizer - >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro') - >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" - >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria" - >>> batch: dict = tokenizer.prepare_seq2seq_batch( - ... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian - ... ) - - """ - - vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"} - max_model_input_sizes = {m: 1024 for m in _all_mbart_models} - pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}} - - prefix_tokens: List[int] = [] - suffix_tokens: List[int] = [] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.sp_model_size = len(self.sp_model) - self.lang_code_to_id = { - code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES) - } - self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()} - self.cur_lang_code = self.lang_code_to_id["en_XX"] - self.fairseq_tokens_to_ids[""] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset - - self.fairseq_tokens_to_ids.update(self.lang_code_to_id) - self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} - self._additional_special_tokens = list(self.lang_code_to_id.keys()) - self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX")) - - def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Build model inputs from a sequence or a pair of sequence for sequence classification tasks - by concatenating and adding special tokens. The special tokens depend on calling set_lang. - An MBART sequence has the following format, where ``X`` represents the sequence: - - ``input_ids`` (for encoder) ``X [eos, src_lang_code]`` - - ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]`` - BOS is never used. 
- Pairs of sequences are not the expected use case, but they will be handled without a separator. - - Args: - token_ids_0 (:obj:`List[int]`): - List of IDs to which the special tokens will be added - token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): - Optional second list of IDs for sequence pairs. - - Returns: - :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. - """ - if token_ids_1 is None: - return self.prefix_tokens + token_ids_0 + self.suffix_tokens - # We don't expect to process pairs, but leave the pair logic for API consistency - return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens - - def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False - ) -> List[int]: - """ - Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer ``prepare_for_model`` methods. - - Args: - token_ids_0 (:obj:`List[int]`): - List of ids. - token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): - Optional second list of IDs for sequence pairs. - already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): - Set to True if the token list is already formatted with special tokens for the model - - Returns: - :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. - """ - - if already_has_special_tokens: - if token_ids_1 is not None: - raise ValueError( - "You should not supply a second sequence if the provided sequence of " - "ids is already formated with special tokens for the model." - ) - return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) - prefix_ones = [1] * len(self.prefix_tokens) - suffix_ones = [1] * len(self.suffix_tokens) - if token_ids_1 is None: - return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones - return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones - - @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - src_texts: List[str], - src_lang: str = "en_XX", - tgt_texts: Optional[List[str]] = None, - tgt_lang: str = "ro_RO", - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - truncation: bool = True, - padding: str = "longest", - return_tensors: str = "pt", - **kwargs, - ) -> BatchEncoding: - """Prepare a batch that can be passed directly to an instance of MBartModel. - - Arguments: - src_texts: (:obj:`list`): - list of documents to summarize or source language texts - src_lang: (:obj:`str`, `optional`, default='en_XX'): - default en_XX (english), the language we are translating from - tgt_texts: (:obj:`list`, `optional`): - list of tgt language texts or summaries. - tgt_lang: (:obj:`str`, `optional`, default='ro_RO'): - default ro_RO (romanian), the language we are translating to - max_length (:obj:`int`, `optional`): - Controls the maximum length for encoder inputs (documents to summarize or source language texts) - If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum - length is required by one of the truncation/padding parameters. If the model has no specific maximum - input length (like XLNet) truncation/padding to a maximum length will be deactivated. 
- max_target_length (:obj:`int`, `optional`): - Controls the maximum length of decoder inputs (target language texts or summaries) - If left unset or set to :obj:`None`, this will use the max_length value. - padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): - Activates and controls padding. Accepts the following values: - - * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a - single sequence if provided). - * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the - maximum acceptable input length for the model if that argument is not provided. - * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of - different lengths). - return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): - If set, will return tensors instead of list of python integers. Acceptable values are: - - * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. - * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. - * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. - truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): - Activates and controls truncation. Accepts the following values: - - * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument - :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not - provided. This will truncate token by token, removing a token from the longest sequence in the pair - if a pair of sequences (or a batch of pairs) is provided. - * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to - the maximum acceptable input length for the model if that argument is not provided. This will only - truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or - to the maximum acceptable input length for the model if that argument is not provided. This will only - truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with - sequence lengths greater than the model maximum admissible input size). - - Return: - :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: - - - **input_ids** -- List of token ids to be fed to the encoder. - - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. - - **decoder_input_ids** -- List of token ids to be fed to the decoder. - - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. - This does not include causal mask, which is built by the model. - - The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, - will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. 
-
- """
- if max_length is None:
- max_length = self.max_len
- self.set_src_lang_special_tokens(src_lang)
- model_inputs: BatchEncoding = self(
- src_texts,
- add_special_tokens=True,
- return_tensors=return_tensors,
- max_length=max_length,
- padding=padding,
- truncation=truncation,
- **kwargs,
- )
- if tgt_texts is None:
- return model_inputs
- # Process tgt_texts
- if max_target_length is None:
- max_target_length = max_length
- self.set_tgt_lang_special_tokens(tgt_lang)
- decoder_inputs: BatchEncoding = self(
- tgt_texts,
- add_special_tokens=True,
- return_tensors=return_tensors,
- padding=padding,
- max_length=max_target_length,
- truncation=True,
- **kwargs,
- )
- for k, v in decoder_inputs.items():
- model_inputs[f"decoder_{k}"] = v
-
- self.set_src_lang_special_tokens(src_lang) # sets to src_lang
- return model_inputs
-
- def set_src_lang_special_tokens(self, src_lang) -> None:
- """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
- self.cur_lang_code = self.lang_code_to_id[src_lang]
- self.prefix_tokens = []
- self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
-
- def set_tgt_lang_special_tokens(self, lang: str) -> None:
- """Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos]."""
- self.cur_lang_code = self.lang_code_to_id[lang]
- self.prefix_tokens = [self.cur_lang_code]
- self.suffix_tokens = [self.eos_token_id]
diff --git a/src/transformers/tokenization_mbart.py b/src/transformers/tokenization_mbart.py
new file mode 100644
index 000000000000..e74a765340f4
--- /dev/null
+++ b/src/transformers/tokenization_mbart.py
@@ -0,0 +1,279 @@
+# coding=utf-8
+# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import List, Optional
+
+from .file_utils import add_start_docstrings_to_callable
+from .tokenization_utils import BatchEncoding
+from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
+from .tokenization_xlm_roberta import XLMRobertaTokenizer
+
+
+logger = logging.getLogger(__name__)
+
+_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
+SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
+
+FAIRSEQ_LANGUAGE_CODES = [
+ "ar_AR",
+ "cs_CZ",
+ "de_DE",
+ "en_XX",
+ "es_XX",
+ "et_EE",
+ "fi_FI",
+ "fr_XX",
+ "gu_IN",
+ "hi_IN",
+ "it_IT",
+ "ja_XX",
+ "kk_KZ",
+ "ko_KR",
+ "lt_LT",
+ "lv_LV",
+ "my_MM",
+ "ne_NP",
+ "nl_XX",
+ "ro_RO",
+ "ru_RU",
+ "si_LK",
+ "tr_TR",
+ "vi_VN",
+ "zh_CN",
+]
+
+
+class MBartTokenizer(XLMRobertaTokenizer):
+ """
+ This inherits from XLMRobertaTokenizer. ``prepare_seq2seq_batch`` should be used to encode inputs.
+ Other tokenizer methods like ``encode`` do not work properly.
+ The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and
+ ``<language code> <tokens> <eos>`` for target language documents.
+
+ Examples::
+
+ >>> from transformers import MBartTokenizer
+ >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
+ >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
+ >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
+ >>> batch: dict = tokenizer.prepare_seq2seq_batch(
+ ... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
+ ... )
+
+ """
+
+ vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
+ max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
+ pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
+
+ prefix_tokens: List[int] = []
+ suffix_tokens: List[int] = []
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.sp_model_size = len(self.sp_model)
+ self.lang_code_to_id = {
+ code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
+ }
+ self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
+ self.cur_lang_code = self.lang_code_to_id["en_XX"]
+ self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
+
+ self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
+ self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
+ self._additional_special_tokens = list(self.lang_code_to_id.keys())
+ self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks
+ by concatenating and adding special tokens. The special tokens depend on calling set_lang.
+ An MBART sequence has the following format, where ``X`` represents the sequence:
+ - ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
+ - ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
+ BOS is never used.
+ Pairs of sequences are not the expected use case, but they will be handled without a separator.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return self.prefix_tokens + token_ids_0 + self.suffix_tokens
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` methods.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of ids.
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True if the token list is already formatted with special tokens for the model + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formated with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) + def prepare_seq2seq_batch( + self, + src_texts: List[str], + src_lang: str = "en_XX", + tgt_texts: Optional[List[str]] = None, + tgt_lang: str = "ro_RO", + max_length: Optional[int] = None, + max_target_length: Optional[int] = None, + truncation: bool = True, + padding: str = "longest", + return_tensors: str = "pt", + **kwargs, + ) -> BatchEncoding: + """Prepare a batch that can be passed directly to an instance of MBartModel. + + Arguments: + src_texts: (:obj:`list`): + list of documents to summarize or source language texts + src_lang: (:obj:`str`, `optional`, default='en_XX'): + default en_XX (english), the language we are translating from + tgt_texts: (:obj:`list`, `optional`): + list of tgt language texts or summaries. + tgt_lang: (:obj:`str`, `optional`, default='ro_RO'): + default ro_RO (romanian), the language we are translating to + max_length (:obj:`int`, `optional`): + Controls the maximum length for encoder inputs (documents to summarize or source language texts) + If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum + length is required by one of the truncation/padding parameters. If the model has no specific maximum + input length (like XLNet) truncation/padding to a maximum length will be deactivated. + max_target_length (:obj:`int`, `optional`): + Controls the maximum length of decoder inputs (target language texts or summaries) + If left unset or set to :obj:`None`, this will use the max_length value. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. 
+ truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): + Activates and controls truncation. Accepts the following values: + + * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument + :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not + provided. This will truncate token by token, removing a token from the longest sequence in the pair + if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to + the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with + sequence lengths greater than the model maximum admissible input size). + + Return: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + + - **input_ids** -- List of token ids to be fed to the encoder. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **decoder_input_ids** -- List of token ids to be fed to the decoder. + - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. + This does not include causal mask, which is built by the model. + + The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, + will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. + + """ + if max_length is None: + max_length = self.max_len + self.set_src_lang_special_tokens(src_lang) + model_inputs: BatchEncoding = self( + src_texts, + add_special_tokens=True, + return_tensors=return_tensors, + max_length=max_length, + padding=padding, + truncation=truncation, + **kwargs, + ) + if tgt_texts is None: + return model_inputs + # Process tgt_texts + if max_target_length is None: + max_target_length = max_length + self.set_tgt_lang_special_tokens(tgt_lang) + decoder_inputs: BatchEncoding = self( + tgt_texts, + add_special_tokens=True, + return_tensors=return_tensors, + padding=padding, + max_length=max_target_length, + truncation=True, + **kwargs, + ) + for k, v in decoder_inputs.items(): + model_inputs[f"decoder_{k}"] = v + + self.set_src_lang_special_tokens(src_lang) # sets to src_lang + return model_inputs + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code].""" + self.cur_lang_code = self.lang_code_to_id[src_lang] + self.prefix_tokens = [] + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target language setting. 
Prefix [tgt_lang_code], suffix =[eos].""" + self.cur_lang_code = self.lang_code_to_id[lang] + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index cbe9b34beeff..3121980c0d5b 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -207,10 +207,10 @@ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_to # Make sure we don't split on any special tokens (even they were already in the vocab before e.g. for Albert) if special_tokens: - self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens))) + self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(new_tokens))) else: # Or on the newly added tokens - self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add))) + self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add))) return len(tokens_to_add) diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index 5bc21ca0c40d..e56a04369bd9 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -20,10 +20,9 @@ from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device -# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented -# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest from .test_modeling_bert import BertModelTester from .test_modeling_common import ids_tensor +from .test_modeling_gpt2 import GPT2ModelTester from .test_modeling_roberta import RobertaModelTester @@ -31,6 +30,7 @@ from transformers import ( BertModel, BertLMHeadModel, + GPT2LMHeadModel, RobertaModel, RobertaForCausalLM, EncoderDecoderModel, @@ -424,3 +424,59 @@ def prepare_config_and_inputs(self): def get_pretrained_model(self): return EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base") + + +class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): + def get_encoder_decoder_model(self, config, decoder_config): + encoder_model = BertModel(config) + decoder_model = GPT2LMHeadModel(decoder_config) + return encoder_model, decoder_model + + def prepare_config_and_inputs(self): + model_tester_encoder = BertModelTester(self, batch_size=13) + model_tester_decoder = GPT2ModelTester(self, batch_size=13) + encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs() + decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = encoder_config_and_inputs + ( + decoder_config, + decoder_input_ids, + decoder_input_mask, + decoder_head_mask, + decoder_token_type_ids, + decoder_sequence_labels, + decoder_token_labels, + decoder_choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) = decoder_config_and_inputs + + # make sure that cross attention layers are added + decoder_config.add_cross_attention = True + # disable cache for now + decoder_config.use_cache = False + return { + "config": config, + "input_ids": input_ids, + "attention_mask": input_mask, + "decoder_config": decoder_config, + "decoder_input_ids": decoder_input_ids, + "decoder_token_type_ids": 
decoder_token_type_ids, + "decoder_attention_mask": decoder_input_mask, + "decoder_sequence_labels": decoder_sequence_labels, + "decoder_token_labels": decoder_token_labels, + "decoder_choice_labels": decoder_choice_labels, + "encoder_hidden_states": encoder_hidden_states, + "labels": decoder_token_labels, + } + + def get_pretrained_model(self): + return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2") diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index ebb6007e6f17..dd4ca1d304d8 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -20,7 +20,7 @@ from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester -from .test_modeling_common import ModelTesterMixin, ids_tensor +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor if is_torch_available(): @@ -62,27 +62,27 @@ def __init__( scope=None, ): self.parent = parent - self.batch_size = 14 - self.seq_length = 7 - self.is_training = True - self.use_token_type_ids = True - self.use_input_mask = True - self.use_labels = True - self.use_mc_token_ids = True - self.vocab_size = 99 - self.hidden_size = 32 - self.num_hidden_layers = 5 - self.num_attention_heads = 4 - self.intermediate_size = 37 - self.hidden_act = "gelu" - self.hidden_dropout_prob = 0.1 - self.attention_probs_dropout_prob = 0, 1 - self.max_position_embeddings = 512 - self.type_vocab_size = 16 - self.type_sequence_label_size = 2 - self.initializer_range = 0.02 - self.num_labels = 3 - self.num_choices = 4 + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_token_type_ids = use_token_type_ids + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.use_mc_token_ids = use_mc_token_ids + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices self.scope = None self.bos_token_id = vocab_size - 1 self.eos_token_id = vocab_size - 1 @@ -142,6 +142,35 @@ def prepare_config_and_inputs(self): choice_labels, ) + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): model = GPT2Model(config=config) model.to(torch_device) diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py index 65ab6c1fb9b4..810ceae74e61 100644 --- a/tests/test_modeling_mbart.py +++ 
b/tests/test_modeling_mbart.py @@ -11,8 +11,8 @@ import torch from transformers import ( AutoModelForSeq2SeqLM, - BartConfig, - BartForConditionalGeneration, + MBartConfig, + MBartForConditionalGeneration, BatchEncoding, AutoTokenizer, ) @@ -92,7 +92,7 @@ def test_mbart_enro_config(self): mbart_models = ["facebook/mbart-large-en-ro"] expected = {"scale_embedding": True, "output_past": True} for name in mbart_models: - config = BartConfig.from_pretrained(name) + config = MBartConfig.from_pretrained(name) self.assertTrue(config.is_valid_mbart()) for k, v in expected.items(): try: @@ -102,7 +102,7 @@ def test_mbart_enro_config(self): raise def test_mbart_fast_forward(self): - config = BartConfig( + config = MBartConfig( vocab_size=99, d_model=24, encoder_layers=2, @@ -115,7 +115,7 @@ def test_mbart_fast_forward(self): add_final_layer_norm=True, return_dict=True, ) - lm_model = BartForConditionalGeneration(config).to(torch_device) + lm_model = MBartForConditionalGeneration(config).to(torch_device) context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device) summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device) result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary) diff --git a/utils/check_repo.py b/utils/check_repo.py index afc5abd9d699..9a3154ec313b 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -30,6 +30,7 @@ "test_modeling_tf_xlm_roberta.py", "test_modeling_xlm_roberta.py", "test_modeling_pegasus.py", + "test_modeling_mbart.py", ] # Update this list for models that are not documented with a comment explaining the reason it should not be. @@ -106,7 +107,6 @@ def get_model_test_files(): "test_modeling_common", "test_modeling_encoder_decoder", "test_modeling_marian", - "test_modeling_mbart", "test_modeling_tf_common", ] test_files = []
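
Editor's note: the pieces added above (``MBartTokenizer``, ``MBartForConditionalGeneration``, and the ``prepare_seq2seq_batch`` / ``generate`` path documented in ``model_doc/mbart.rst``) are meant to be used together. A minimal sketch of the intended English-to-Romanian flow, not part of the diff, assuming the ``facebook/mbart-large-en-ro`` checkpoint and starting the decoder from the target language code to match the ``[tgt_lang_code] X [eos]`` layout described in the tokenizer docstrings::

    from transformers import MBartForConditionalGeneration, MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")

    # prepare_seq2seq_batch appends [eos, src_lang_code] to the source ids
    # (set_src_lang_special_tokens) before encoding.
    batch = tokenizer.prepare_seq2seq_batch(
        ["UN Chief Says There Is No Military Solution in Syria"],
        src_lang="en_XX",
        tgt_lang="ro_RO",
        return_tensors="pt",
    )

    # After the call, the tokenizer is reset to the source-language setting:
    # no prefix tokens, suffix == [eos, src_lang_code].
    assert tokenizer.prefix_tokens == []
    assert tokenizer.suffix_tokens == [tokenizer.eos_token_id, tokenizer.lang_code_to_id["en_XX"]]

    # Begin generation with the Romanian language code.
    translated = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"],
    )
    print(tokenizer.batch_decode(translated, skip_special_tokens=True))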
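The ``tokenization_utils.py`` hunk swaps ``list(set(...))`` for ``sorted(set(...))`` when rebuilding ``unique_no_split_tokens``. As I read it, the motivation is reproducibility: with Python's hash randomization, iterating a set of strings can give a different order on every interpreter run, so sorting keeps the no-split token list stable across runs. A standalone illustration (not from the diff, token names are arbitrary)::

    existing = ["<mask>", "<pad>"]
    new_tokens = ["en_XX", "ro_RO", "<s>"]

    # Iteration order of a plain set of strings can change between interpreter runs.
    non_deterministic = list(set(existing).union(new_tokens))

    # Sorting fixes the order, so the resulting tokenizer state is reproducible.
    deterministic = sorted(set(existing).union(new_tokens))
    print(deterministic)  # ['<mask>', '<pad>', '<s>', 'en_XX', 'ro_RO']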
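The new ``GPT2EncoderDecoderModelTest`` pairs a BERT encoder with a GPT-2 decoder, setting ``decoder_config.add_cross_attention = True`` and disabling the cache when it builds the models from configs. Outside the test harness, ``get_pretrained_model`` suggests the same pairing can be warm-started from checkpoints roughly as below; the toy decoder inputs are my own and only exercise the forward pass::

    from transformers import BertTokenizer, EncoderDecoderModel, GPT2Tokenizer

    # Warm-start a BERT encoder / GPT-2 decoder pair, as in get_pretrained_model().
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")

    bert_tok = BertTokenizer.from_pretrained("bert-base-cased")
    gpt2_tok = GPT2Tokenizer.from_pretrained("gpt2")

    enc = bert_tok("The encoder reads this sentence.", return_tensors="pt")
    dec = gpt2_tok("The decoder continues it.", return_tensors="pt")

    outputs = model(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        decoder_input_ids=dec["input_ids"],
        labels=dec["input_ids"],  # toy labels, only to obtain a loss from the forward pass
    )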