diff --git a/examples/question-answering/run_squad.py b/examples/question-answering/run_squad.py index 59550347c275..4f8fe05a8645 100644 --- a/examples/question-answering/run_squad.py +++ b/examples/question-answering/run_squad.py @@ -730,6 +730,7 @@ def main(): args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case, cache_dir=args.cache_dir if args.cache_dir else None, + use_fast=False, # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handeling ) model = AutoModelForQuestionAnswering.from_pretrained( args.model_name_or_path, @@ -778,7 +779,10 @@ def main(): # Load a trained model and vocabulary that you have fine-tuned model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir) # , force_download=True) - tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) + + # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handeling + # So we use use_fast=False here for now until Fast-tokenizer-compatible-examples are out + tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case, use_fast=False) model.to(args.device) # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory diff --git a/examples/question-answering/run_squad_trainer.py b/examples/question-answering/run_squad_trainer.py index d5fc0723164a..0bb357b21e8e 100644 --- a/examples/question-answering/run_squad_trainer.py +++ b/examples/question-answering/run_squad_trainer.py @@ -107,6 +107,7 @@ def main(): tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, + use_fast=False, # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handeling ) model = AutoModelForQuestionAnswering.from_pretrained( model_args.model_name_or_path, diff --git a/examples/requirements.txt b/examples/requirements.txt index 9c2704796789..1ce783440f6e 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -18,3 +18,4 @@ fire pytest conllu sentencepiece != 0.1.92 +protobuf diff --git a/examples/seq2seq/test_datasets.py b/examples/seq2seq/test_datasets.py index 625b6da347d3..4cbce79eaa92 100644 --- a/examples/seq2seq/test_datasets.py +++ b/examples/seq2seq/test_datasets.py @@ -197,7 +197,7 @@ def test_distributed_sortish_sampler_splits_indices_between_procs(self): ) @require_torch_non_multigpu_but_fix_me def test_dataset_kwargs(self, tok_name): - tokenizer = AutoTokenizer.from_pretrained(tok_name) + tokenizer = AutoTokenizer.from_pretrained(tok_name, use_fast=False) if tok_name == MBART_TINY: train_dataset = Seq2SeqDataset( tokenizer, diff --git a/setup.py b/setup.py index 04c51912fdc9..7e7e34661b6f 100644 --- a/setup.py +++ b/setup.py @@ -96,12 +96,12 @@ extras["retrieval"] = ["faiss-cpu", "datasets"] extras["flax"] = ["jaxlib==0.1.55", "jax>=0.2.0", "flax==0.2.2"] -extras["tokenizers"] = ["tokenizers==0.9.2"] +extras["tokenizers"] = ["tokenizers==0.9.4"] extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"] extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"] -extras["sentencepiece"] = ["sentencepiece==0.1.91"] +extras["sentencepiece"] = ["sentencepiece==0.1.91", "protobuf"] extras["retrieval"] = ["faiss-cpu", "datasets"] extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil"] + extras["retrieval"] # sphinx-rtd-theme==0.5.0 introduced big changes in the style. @@ -129,7 +129,7 @@ packages=find_packages("src"), install_requires=[ "numpy", - "tokenizers == 0.9.3", + "tokenizers == 0.9.4", # dataclasses for Python versions that don't have it "dataclasses;python_version<'3.7'", # utilities from PyPA to e.g. compare versions @@ -142,9 +142,6 @@ "tqdm >= 4.27", # for OpenAI GPT "regex != 2019.12.17", - # for SentencePiece models - "sentencepiece == 0.1.91", - "protobuf", # for XLM "sacremoses", ], diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index 8c765943c217..7e988e7fdd73 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -24,10 +24,7 @@ from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors from tokenizers.models import BPE, Unigram, WordPiece -# from transformers.tokenization_openai import OpenAIGPTTokenizer -from transformers.utils import sentencepiece_model_pb2 as model - -from .file_utils import requires_sentencepiece +from .file_utils import requires_protobuf, requires_sentencepiece class SentencePieceExtractor: @@ -64,12 +61,6 @@ def check_number_comma(piece: str) -> bool: return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit() -def get_proto(filename: str): - m = model.ModelProto() - m.ParseFromString(open(filename, "rb").read()) - return m - - class Converter: def __init__(self, original_tokenizer): self.original_tokenizer = original_tokenizer @@ -292,8 +283,15 @@ def converted(self) -> Tokenizer: class SpmConverter(Converter): def __init__(self, *args): + requires_protobuf(self) + super().__init__(*args) - self.proto = get_proto(self.original_tokenizer.vocab_file) + + from .utils import sentencepiece_model_pb2 as model_pb2 + + m = model_pb2.ModelProto() + m.ParseFromString(open(self.original_tokenizer.vocab_file, "rb").read()) + self.proto = m def vocab(self, proto): return [(piece.piece, piece.score) for piece in proto.pieces] diff --git a/src/transformers/data/processors/squad.py b/src/transformers/data/processors/squad.py index 89ef2e22b67f..167cf3ee48d9 100644 --- a/src/transformers/data/processors/squad.py +++ b/src/transformers/data/processors/squad.py @@ -8,7 +8,7 @@ from ...file_utils import is_tf_available, is_torch_available from ...tokenization_bert import whitespace_tokenize -from ...tokenization_utils_base import PreTrainedTokenizerBase, TruncationStrategy +from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy from ...utils import logging from .utils import DataProcessor @@ -765,6 +765,7 @@ class SquadFeatures: token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer. start_position: start of the answer token index end_position: end of the answer token index + encoding: optionally store the BatchEncoding with the fast-tokenizer alignement methods. """ def __init__( @@ -784,6 +785,7 @@ def __init__( end_position, is_impossible, qas_id: str = None, + encoding: BatchEncoding = None, ): self.input_ids = input_ids self.attention_mask = attention_mask @@ -803,6 +805,8 @@ def __init__( self.is_impossible = is_impossible self.qas_id = qas_id + self.encoding = encoding + class SquadResult: """ diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index f6b63fa8962f..374b10dafabe 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -185,6 +185,15 @@ _sentencepiece_available = False +try: + import google.protobuf # noqa: F401 + + _protobuf_available = True + +except ImportError: + _protobuf_available = False + + try: import tokenizers # noqa: F401 @@ -270,6 +279,10 @@ def is_sentencepiece_available(): return _sentencepiece_available +def is_protobuf_available(): + return _protobuf_available + + def is_tokenizers_available(): return _tokenizers_available @@ -330,6 +343,14 @@ def wrapper(*args, **kwargs): """ +# docstyle-ignore +PROTOBUF_IMPORT_ERROR = """ +{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the +installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones +that match your environment. +""" + + # docstyle-ignore FAISS_IMPORT_ERROR = """ {0} requires the faiss library but it was not found in your environment. Checkout the instructions on the @@ -420,6 +441,12 @@ def requires_sentencepiece(obj): raise ImportError(SENTENCEPIECE_IMPORT_ERROR.format(name)) +def requires_protobuf(obj): + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + if not is_protobuf_available(): + raise ImportError(PROTOBUF_IMPORT_ERROR.format(name)) + + def add_start_docstrings(*docstr): def docstring_decorator(fn): fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index f48edb060cd0..8cc533d980ee 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -32,7 +32,7 @@ from .configuration_auto import AutoConfig from .configuration_utils import PretrainedConfig -from .data import SquadExample, squad_convert_examples_to_features +from .data import SquadExample, SquadFeatures, squad_convert_examples_to_features from .file_utils import add_end_docstrings, is_tf_available, is_torch_available from .modelcard import ModelCard from .tokenization_auto import AutoTokenizer @@ -1758,6 +1758,7 @@ def __call__(self, *args, **kwargs): - **answer** (:obj:`str`) -- The answer to the question. """ # Set defaults values + kwargs.setdefault("padding", "longest") kwargs.setdefault("topk", 1) kwargs.setdefault("doc_stride", 128) kwargs.setdefault("max_answer_len", 15) @@ -1773,19 +1774,87 @@ def __call__(self, *args, **kwargs): # Convert inputs to features examples = self._args_parser(*args, **kwargs) - features_list = [ - squad_convert_examples_to_features( - examples=[example], - tokenizer=self.tokenizer, - max_seq_length=kwargs["max_seq_len"], - doc_stride=kwargs["doc_stride"], - max_query_length=kwargs["max_question_len"], - padding_strategy=PaddingStrategy.MAX_LENGTH.value, - is_training=False, - tqdm_enabled=False, - ) - for example in examples - ] + if not self.tokenizer.is_fast: + features_list = [ + squad_convert_examples_to_features( + examples=[example], + tokenizer=self.tokenizer, + max_seq_length=kwargs["max_seq_len"], + doc_stride=kwargs["doc_stride"], + max_query_length=kwargs["max_question_len"], + padding_strategy=PaddingStrategy.MAX_LENGTH.value, + is_training=False, + tqdm_enabled=False, + ) + for example in examples + ] + else: + features_list = [] + for example in examples: + # Define the side we want to truncate / pad and the text/pair sorting + question_first = bool(self.tokenizer.padding_side == "right") + + encoded_inputs = self.tokenizer( + text=example.question_text if question_first else example.context_text, + text_pair=example.context_text if question_first else example.question_text, + padding=kwargs["padding"], + truncation="only_second" if question_first else "only_first", + max_length=kwargs["max_seq_len"], + stride=kwargs["doc_stride"], + return_tensors="np", + return_token_type_ids=True, + return_overflowing_tokens=True, + return_offsets_mapping=True, + return_special_tokens_mask=True, + ) + + # When the input is too long, it's converted in a batch of inputs with overflowing tokens + # and a stride of overlap between the inputs. If a batch of inputs is given, a special output + # "overflow_to_sample_mapping" indicate which member of the encoded batch belong to which original batch sample. + # Here we tokenize examples one-by-one so we don't need to use "overflow_to_sample_mapping". + # "num_span" is the number of output samples generated from the overflowing tokens. + num_spans = len(encoded_inputs["input_ids"]) + + # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer) + # We put 0 on the tokens from the context and 1 everywhere else (question and special tokens) + p_mask = np.asarray( + [ + [tok != 1 if question_first else 0 for tok in encoded_inputs.sequence_ids(span_id)] + for span_id in range(num_spans) + ] + ) + + # keep the cls_token unmasked (some models use it to indicate unanswerable questions) + if self.tokenizer.cls_token_id: + cls_index = np.nonzero(encoded_inputs["input_ids"] == self.tokenizer.cls_token_id) + p_mask[cls_index] = 0 + + features = [] + for span_idx in range(num_spans): + features.append( + SquadFeatures( + input_ids=encoded_inputs["input_ids"][span_idx], + attention_mask=encoded_inputs["attention_mask"][span_idx], + token_type_ids=encoded_inputs["token_type_ids"][span_idx], + p_mask=p_mask[span_idx].tolist(), + encoding=encoded_inputs[span_idx], + # We don't use the rest of the values - and actually + # for Fast tokenizer we could totally avoid using SquadFeatures and SquadExample + cls_index=None, + token_to_orig_map={}, + example_index=0, + unique_id=0, + paragraph_len=0, + token_is_max_context=0, + tokens=[], + start_position=0, + end_position=0, + is_impossible=False, + qas_id=None, + ) + ) + features_list.append(features) + all_answers = [] for features, example in zip(features_list, examples): model_input_names = self.tokenizer.model_input_names + ["input_ids"] @@ -1828,20 +1897,56 @@ def __call__(self, *args, **kwargs): start_[0] = end_[0] = 0.0 starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"]) - char_to_word = np.array(example.char_to_word_offset) - - # Convert the answer (tokens) back to the original text - answers += [ - { - "score": score.item(), - "start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(), - "end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(), - "answer": " ".join( - example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1] - ), - } - for s, e, score in zip(starts, ends, scores) - ] + if not self.tokenizer.is_fast: + char_to_word = np.array(example.char_to_word_offset) + + # Convert the answer (tokens) back to the original text + # Score: score from the model + # Start: Index of the first character of the answer in the context string + # End: Index of the character following the last character of the answer in the context string + # Answer: Plain text of the answer + answers += [ + { + "score": score.item(), + "start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(), + "end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(), + "answer": " ".join( + example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1] + ), + } + for s, e, score in zip(starts, ends, scores) + ] + else: + # Convert the answer (tokens) back to the original text + # Score: score from the model + # Start: Index of the first character of the answer in the context string + # End: Index of the character following the last character of the answer in the context string + # Answer: Plain text of the answer + question_first = bool(self.tokenizer.padding_side == "right") + enc = feature.encoding + + # Sometimes the max probability token is in the middle of a word so: + # - we start by finding the right word containing the token with `token_to_word` + # - then we convert this word in a character span with `word_to_chars` + answers += [ + { + "score": score.item(), + "start": enc.word_to_chars( + enc.token_to_word(s), sequence_index=1 if question_first else 0 + )[0], + "end": enc.word_to_chars(enc.token_to_word(e), sequence_index=1 if question_first else 0)[ + 1 + ], + "answer": example.context_text[ + enc.word_to_chars(enc.token_to_word(s), sequence_index=1 if question_first else 0)[ + 0 + ] : enc.word_to_chars(enc.token_to_word(e), sequence_index=1 if question_first else 0)[ + 1 + ] + ], + } + for s, e, score in zip(starts, ends, scores) + ] if kwargs["handle_impossible_answer"]: answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""}) @@ -2735,7 +2840,7 @@ def pipeline( tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, framework: Optional[str] = None, revision: Optional[str] = None, - use_fast: bool = False, + use_fast: bool = True, **kwargs ) -> Pipeline: """ @@ -2793,7 +2898,7 @@ def pipeline( When passing a task name or a string model identifier: The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git. - use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`): + use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`). kwargs: Additional keyword arguments passed along to the specific pipeline init (see the documentation for the diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index 93c9fbfe64a9..7e375d05986b 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -280,7 +280,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git. - use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`): + use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to try to load the fast version of the tokenizer. kwargs (additional keyword arguments, `optional`): Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like @@ -308,7 +308,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): if "bert-base-japanese" in str(pretrained_model_name_or_path): return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - use_fast = kwargs.pop("use_fast", False) + use_fast = kwargs.pop("use_fast", True) if config.tokenizer_class is not None: if use_fast and not config.tokenizer_class.endswith("Fast"): diff --git a/src/transformers/tokenization_roberta_fast.py b/src/transformers/tokenization_roberta_fast.py index 3709aec944fe..696c43bf53ba 100644 --- a/src/transformers/tokenization_roberta_fast.py +++ b/src/transformers/tokenization_roberta_fast.py @@ -18,6 +18,7 @@ from .tokenization_gpt2_fast import GPT2TokenizerFast from .tokenization_roberta import RobertaTokenizer +from .tokenization_utils_base import AddedToken from .utils import logging @@ -172,6 +173,32 @@ def __init__( **kwargs, ) + @property + def mask_token(self) -> str: + """ + :obj:`str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while + not having been set. + + Roberta tokenizer has a special mask token to be usble in the fill-mask pipeline. The mask token will greedily + comprise the space before the ``. + """ + if self._mask_token is None and self.verbose: + logger.error("Using mask_token, but it is not set yet.") + return None + return str(self._mask_token) + + @mask_token.setter + def mask_token(self, value): + """ + Overriding the default behavior of the mask token to have it eat the space before it. + + This is needed to preserve backward compatibility with all the previously used models based on Roberta. + """ + # Mask token behave like a normal word, i.e. include the space before it + # So we set lstrip to True + value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value + self._mask_token = value + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] if token_ids_1 is None: diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 9420e553822e..a7581b70f8c6 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -182,7 +182,9 @@ def to_py_obj(obj): """ Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list. """ - if isinstance(obj, (list, tuple)): + if isinstance(obj, (dict, BatchEncoding)): + return {k: to_py_obj(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): return [to_py_obj(o) for o in obj] elif is_tf_available() and isinstance(obj, tf.Tensor): return obj.numpy().tolist() @@ -216,6 +218,9 @@ class BatchEncoding(UserDict): initialization. prepend_batch_axis (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to add a batch axis when converting to tensors (see :obj:`tensor_type` above). + n_sequences (:obj:`Optional[int]`, `optional`): + You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + initialization. """ def __init__( @@ -224,6 +229,7 @@ def __init__( encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None, tensor_type: Union[None, str, TensorType] = None, prepend_batch_axis: bool = False, + n_sequences: Optional[int] = None, ): super().__init__(data) @@ -232,8 +238,22 @@ def __init__( self._encodings = encoding + if n_sequences is None and encoding is not None and len(encoding): + n_sequences = encoding[0].n_sequences + + self._n_sequences = n_sequences + self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis) + @property + def n_sequences(self) -> Optional[int]: + """ + :obj:`Optional[int]`: The number of sequences used to generate each sample from the batch encoded in this + :class:`~transformers.BatchEncoding`. Currently can be one of :obj:`None` (unknown), :obj:`1` (a single + sentence) or :obj:`2` (a pair of sentences) + """ + return self.n_sequences + @property def is_fast(self) -> bool: """ @@ -311,6 +331,27 @@ def tokens(self, batch_index: int = 0) -> List[str]: raise ValueError("tokens() is not available when using Python-based tokenizers") return self._encodings[batch_index].tokens + def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]: + """ + Return a list mapping the tokens to the id of their original sentences: + + - :obj:`None` for special tokens added around or between sequences, + - :obj:`0` for tokens corresponding to words in the first sequence, + - :obj:`1` for tokens corresponding to words in the second sequence when a pair of sequences was jointly + encoded. + + Args: + batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch. + + Returns: + :obj:`List[Optional[int]]`: A list indicating the sequence id corresponding to each token. Special tokens + added by the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their + corresponding sequence. + """ + if not self._encodings: + raise ValueError("sequence_ids() is not available when using Python-based tokenizers") + return self._encodings[batch_index].sequence_ids + def words(self, batch_index: int = 0) -> List[Optional[int]]: """ Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer. @@ -325,7 +366,67 @@ def words(self, batch_index: int = 0) -> List[Optional[int]]: """ if not self._encodings: raise ValueError("words() is not available when using Python-based tokenizers") - return self._encodings[batch_index].words + warnings.warn( + "`BatchEncoding.words()` property is deprecated and should be replaced with the identical, " + "but more self-explanatory `BatchEncoding.word_ids()` property.", + FutureWarning, + ) + return self.word_ids(batch_index) + + def word_ids(self, batch_index: int = 0) -> List[Optional[int]]: + """ + Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer. + + Args: + batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch. + + Returns: + :obj:`List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by + the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their corresponding + word (several tokens will be mapped to the same word index if they are parts of that word). + """ + if not self._encodings: + raise ValueError("word_ids() is not available when using Python-based tokenizers") + return self._encodings[batch_index].word_ids + + def token_to_sequence(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int: + """ + Get the index of the sequence represented by the given token. In the general use case, this method returns + :obj:`0` for a single sequence or the first sequence of a pair, and :obj:`1` for the second sequence of a pair + + Can be called as: + + - ``self.token_to_sequence(token_index)`` if batch size is 1 + - ``self.token_to_sequence(batch_index, token_index)`` if batch size is greater than 1 + + This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e., + words are defined by the user). In this case it allows to easily associate encoded tokens with provided + tokenized words. + + Args: + batch_or_token_index (:obj:`int`): + Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of + the token in the sequence. + token_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index of the token in the + sequence. + + Returns: + :obj:`int`: Index of the word in the input sequence. + """ + + if not self._encodings: + raise ValueError("token_to_sequence() is not available when using Python based tokenizers") + if token_index is not None: + batch_index = batch_or_token_index + else: + batch_index = 0 + token_index = batch_or_token_index + if batch_index < 0: + batch_index = self._batch_size + batch_index + if token_index < 0: + token_index = self._seq_len + token_index + return self._encodings[batch_index].token_to_sequence(token_index) def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int: """ @@ -365,9 +466,11 @@ def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = token_index = self._seq_len + token_index return self._encodings[batch_index].token_to_word(token_index) - def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> Optional[TokenSpan]: + def word_to_tokens( + self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0 + ) -> Optional[TokenSpan]: """ - Get the encoded token span corresponding to a word in the sequence of the batch. + Get the encoded token span corresponding to a word in a sequence of the batch. Token spans are returned as a :class:`~transformers.tokenization_utils_base.TokenSpan` with: @@ -376,8 +479,9 @@ def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = N Can be called as: - - ``self.word_to_tokens(word_index)`` if batch size is 1 - - ``self.word_to_tokens(batch_index, word_index)`` if batch size is greater or equal to 1 + - ``self.word_to_tokens(word_index, sequence_index: int = 0)`` if batch size is 1 + - ``self.word_to_tokens(batch_index, word_index, sequence_index: int = 0)`` if batch size is greater or equal + to 1 This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized @@ -390,6 +494,9 @@ def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = N word_index (:obj:`int`, `optional`): If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the sequence. + sequence_index (:obj:`int`, `optional`, defaults to 0): + If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 + or 1) the provided word index belongs to. Returns: Optional :class:`~transformers.tokenization_utils_base.TokenSpan` Span of tokens in the encoded sequence. @@ -407,7 +514,7 @@ def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = N batch_index = self._batch_size + batch_index if word_index < 0: word_index = self._seq_len + word_index - span = self._encodings[batch_index].word_to_tokens(word_index) + span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index) return TokenSpan(*span) if span is not None else None def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan: @@ -446,7 +553,9 @@ def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = token_index = batch_or_token_index return CharSpan(*(self._encodings[batch_index].token_to_chars(token_index))) - def char_to_token(self, batch_or_char_index: int, char_index: Optional[int] = None) -> int: + def char_to_token( + self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0 + ) -> int: """ Get the index of the token in the encoded output comprising a character in the original string for a sequence of the batch. @@ -467,6 +576,9 @@ def char_to_token(self, batch_or_char_index: int, char_index: Optional[int] = No char_index (:obj:`int`, `optional`): If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the sequence. + sequence_index (:obj:`int`, `optional`, defaults to 0): + If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 + or 1) the provided character index belongs to. Returns: @@ -480,9 +592,11 @@ def char_to_token(self, batch_or_char_index: int, char_index: Optional[int] = No else: batch_index = 0 char_index = batch_or_char_index - return self._encodings[batch_index].char_to_token(char_index) + return self._encodings[batch_index].char_to_token(char_index, sequence_index) - def word_to_chars(self, batch_or_word_index: int, word_index: Optional[int] = None) -> CharSpan: + def word_to_chars( + self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0 + ) -> CharSpan: """ Get the character span in the original string corresponding to given word in a sequence of the batch. @@ -503,6 +617,9 @@ def word_to_chars(self, batch_or_word_index: int, word_index: Optional[int] = No word_index (:obj:`int`, `optional`): If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the sequence. + sequence_index (:obj:`int`, `optional`, defaults to 0): + If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 + or 1) the provided word index belongs to. Returns: :obj:`CharSpan` or :obj:`List[CharSpan]`: Span(s) of the associated character or characters in the string. @@ -520,9 +637,9 @@ def word_to_chars(self, batch_or_word_index: int, word_index: Optional[int] = No else: batch_index = 0 word_index = batch_or_word_index - return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index))) + return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index, sequence_index))) - def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None) -> int: + def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int: """ Get the word in the original string corresponding to a character in the original string of a sequence of the batch. @@ -543,6 +660,9 @@ def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = Non char_index (:obj:`int`, `optional`): If a batch index is provided in `batch_or_token_index`, this can be the index of the character in the original string. + sequence_index (:obj:`int`, `optional`, defaults to 0): + If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 + or 1) the provided character index belongs to. Returns: @@ -556,7 +676,7 @@ def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = Non else: batch_index = 0 char_index = batch_or_char_index - return self._encodings[batch_index].char_to_word(char_index) + return self._encodings[batch_index].char_to_word(char_index, sequence_index) def convert_to_tensors( self, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False @@ -1872,6 +1992,8 @@ def _save_pretrained( "Only fast tokenizers (instances of PretrainedTokenizerFast) can be saved in non legacy format." ) + save_directory = str(save_directory) + added_tokens_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE ) diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py index 8552aae9d256..c672a0b02ef2 100644 --- a/src/transformers/tokenization_utils_fast.py +++ b/src/transformers/tokenization_utils_fast.py @@ -169,9 +169,10 @@ def _convert_encoding( return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, - ) -> Dict[str, Any]: + ) -> Tuple[Dict[str, Any], List[EncodingFast]]: """ - Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict. + Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict and a list + of encodings, take care of building a batch from overflowing tokens. Overflowing tokens are converted to additional examples (like batches) so the output values of the dict are lists (overflows) of lists (tokens). @@ -203,7 +204,7 @@ def _convert_encoding( if return_length: encoding_dict["length"].append(len(e.ids)) - return encoding_dict + return encoding_dict, encodings def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]: """ @@ -390,9 +391,12 @@ def _batch_encode_plus( ) # Convert encoding to dict - # `Tokens` has type: List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]] + # `Tokens` has type: Tuple[ + # List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]], + # List[EncodingFast] + # ] # with nested dimensions corresponding to batch, overflows, sequence length - tokens = [ + tokens_and_encodings = [ self._convert_encoding( encoding=encoding, return_token_type_ids=return_token_type_ids, @@ -406,22 +410,27 @@ def _batch_encode_plus( for encoding in encodings ] - # Convert the output to have dict[list] from list[dict] - sanitized = {} - for key in tokens[0].keys(): - # To List[List[List[int]]] of shape (batch, overflows, sequence length) - stack = [e for item in tokens for e in item[key]] - sanitized[key] = stack + # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension + # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length) + # (we say ~ because the number of overflow varies with the example in the batch) + # + # To match each overflowing sample with the original sample in the batch + # we add an overflow_to_sample_mapping array (see below) + sanitized_tokens = {} + for key in tokens_and_encodings[0][0].keys(): + stack = [e for item, _ in tokens_and_encodings for e in item[key]] + sanitized_tokens[key] = stack + sanitized_encodings = [e for _, item in tokens_and_encodings for e in item] # If returning overflowing tokens, we need to return a mapping # from the batch idx to the original sample if return_overflowing_tokens: overflow_to_sample_mapping = [] - for i, enc in enumerate(tokens): - overflow_to_sample_mapping += [i] * len(enc["input_ids"]) - sanitized["overflow_to_sample_mapping"] = overflow_to_sample_mapping + for i, (toks, _) in enumerate(tokens_and_encodings): + overflow_to_sample_mapping += [i] * len(toks["input_ids"]) + sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping - return BatchEncoding(sanitized, encodings, tensor_type=return_tensors) + return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors) def _encode_plus( self, @@ -518,6 +527,8 @@ def _save_pretrained( Fast tokenizers can also be saved in a unique JSON file containing {config + vocab + added-tokens} using the specific :meth:`~transformers.PreTrainedTokenizerFast._save_pretrained` """ + save_directory = str(save_directory) + if legacy_format: added_tokens_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index 697df13a1781..736ac9612081 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -1,10 +1,10 @@ from typing import List, Optional +from unittest import mock from transformers import is_tf_available, is_torch_available, pipeline - -# from transformers.pipelines import DefaultArgumentHandler, Pipeline from transformers.pipelines import Pipeline from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow +from transformers.tokenization_utils_base import to_py_obj VALID_INPUTS = ["A simple string", ["list of strings"]] @@ -13,9 +13,11 @@ @is_pipeline_test class CustomInputPipelineCommonMixin: pipeline_task = None - pipeline_loading_kwargs = {} - small_models = None # Models tested without the @slow decorator - large_models = None # Models tested with the @slow decorator + pipeline_loading_kwargs = {} # Additional kwargs to load the pipeline with + pipeline_running_kwargs = {} # Additional kwargs to run the pipeline with + small_models = [] # Models tested without the @slow decorator + large_models = [] # Models tested with the @slow decorator + valid_inputs = VALID_INPUTS # Some inputs which are valid to compare fast and slow tokenizers def setUp(self) -> None: if not is_tf_available() and not is_torch_available(): @@ -47,73 +49,11 @@ def setUp(self) -> None: @require_torch @slow def test_pt_defaults(self): - pipeline(self.pipeline_task, framework="pt") - - @require_tf - @slow - def test_tf_defaults(self): - pipeline(self.pipeline_task, framework="tf") - - @require_torch - def test_torch_small(self): - for model_name in self.small_models: - nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="pt") - self._test_pipeline(nlp) - - @require_tf - def test_tf_small(self): - for model_name in self.small_models: - nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="tf") - self._test_pipeline(nlp) - - @require_torch - @slow - def test_torch_large(self): - for model_name in self.large_models: - nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="pt") - self._test_pipeline(nlp) - - @require_tf - @slow - def test_tf_large(self): - for model_name in self.large_models: - nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="tf") - self._test_pipeline(nlp) - - def _test_pipeline(self, nlp: Pipeline): - raise NotImplementedError - - -@is_pipeline_test -class MonoInputPipelineCommonMixin: - pipeline_task = None - pipeline_loading_kwargs = {} # Additional kwargs to load the pipeline with - pipeline_running_kwargs = {} # Additional kwargs to run the pipeline with - small_models = [] # Models tested without the @slow decorator - large_models = [] # Models tested with the @slow decorator - mandatory_keys = {} # Keys which should be in the output - valid_inputs = VALID_INPUTS # inputs which are valid - invalid_inputs = [None] # inputs which are not allowed - expected_multi_result: Optional[List] = None - expected_check_keys: Optional[List[str]] = None - - def setUp(self) -> None: - if not is_tf_available() and not is_torch_available(): - return # Currently no JAX pipelines - - for model_name in self.small_models: - pipeline(self.pipeline_task, model=model_name, tokenizer=model_name, **self.pipeline_loading_kwargs) - for model_name in self.large_models: - pipeline(self.pipeline_task, model=model_name, tokenizer=model_name, **self.pipeline_loading_kwargs) - - @require_torch - @slow - def test_pt_defaults_loads(self): pipeline(self.pipeline_task, framework="pt", **self.pipeline_loading_kwargs) @require_tf @slow - def test_tf_defaults_loads(self): + def test_tf_defaults(self): pipeline(self.pipeline_task, framework="tf", **self.pipeline_loading_kwargs) @require_torch @@ -166,6 +106,95 @@ def test_tf_large(self): ) self._test_pipeline(nlp) + def _test_pipeline(self, nlp: Pipeline): + raise NotImplementedError + + @require_torch + def test_compare_slow_fast_torch(self): + for model_name in self.small_models: + nlp_slow = pipeline( + task=self.pipeline_task, + model=model_name, + tokenizer=model_name, + framework="pt", + use_fast=False, + **self.pipeline_loading_kwargs, + ) + nlp_fast = pipeline( + task=self.pipeline_task, + model=model_name, + tokenizer=model_name, + framework="pt", + use_fast=True, + **self.pipeline_loading_kwargs, + ) + self._compare_slow_fast_pipelines(nlp_slow, nlp_fast, method="forward") + + @require_tf + def test_compare_slow_fast_tf(self): + for model_name in self.small_models: + nlp_slow = pipeline( + task=self.pipeline_task, + model=model_name, + tokenizer=model_name, + framework="tf", + use_fast=False, + **self.pipeline_loading_kwargs, + ) + nlp_fast = pipeline( + task=self.pipeline_task, + model=model_name, + tokenizer=model_name, + framework="tf", + use_fast=True, + **self.pipeline_loading_kwargs, + ) + self._compare_slow_fast_pipelines(nlp_slow, nlp_fast, method="call") + + def _compare_slow_fast_pipelines(self, nlp_slow: Pipeline, nlp_fast: Pipeline, method: str): + """We check that the inputs to the models forward passes are identical for + slow and fast tokenizers. + """ + with mock.patch.object( + nlp_slow.model, method, wraps=getattr(nlp_slow.model, method) + ) as mock_slow, mock.patch.object(nlp_fast.model, method, wraps=getattr(nlp_fast.model, method)) as mock_fast: + for inputs in self.valid_inputs: + if isinstance(inputs, dict): + inputs.update(self.pipeline_running_kwargs) + _ = nlp_slow(**inputs) + _ = nlp_fast(**inputs) + else: + _ = nlp_slow(inputs, **self.pipeline_running_kwargs) + _ = nlp_fast(inputs, **self.pipeline_running_kwargs) + + mock_slow.assert_called() + mock_fast.assert_called() + + self.assertEqual(len(mock_slow.call_args_list), len(mock_fast.call_args_list)) + for mock_slow_call_args, mock_fast_call_args in zip( + mock_slow.call_args_list, mock_slow.call_args_list + ): + slow_call_args, slow_call_kwargs = mock_slow_call_args + fast_call_args, fast_call_kwargs = mock_fast_call_args + + slow_call_args, slow_call_kwargs = to_py_obj(slow_call_args), to_py_obj(slow_call_kwargs) + fast_call_args, fast_call_kwargs = to_py_obj(fast_call_args), to_py_obj(fast_call_kwargs) + + self.assertEqual(slow_call_args, fast_call_args) + self.assertDictEqual(slow_call_kwargs, fast_call_kwargs) + + +@is_pipeline_test +class MonoInputPipelineCommonMixin(CustomInputPipelineCommonMixin): + """A version of the CustomInputPipelineCommonMixin + with a predefined `_test_pipeline` method. + """ + + mandatory_keys = {} # Keys which should be in the output + invalid_inputs = [None] # inputs which are not allowed + expected_multi_result: Optional[List] = None + expected_check_keys: Optional[List[str]] = None + def _test_pipeline(self, nlp: Pipeline): self.assertIsNotNone(nlp) @@ -199,76 +228,3 @@ def _test_pipeline(self, nlp: Pipeline): self.assertIn(key, result) self.assertRaises(Exception, nlp, self.invalid_inputs) - - -# @is_pipeline_test -# class DefaultArgumentHandlerTestCase(unittest.TestCase): -# def setUp(self) -> None: -# self.handler = DefaultArgumentHandler() -# -# def test_kwargs_x(self): -# mono_data = {"X": "This is a sample input"} -# mono_args = self.handler(**mono_data) -# -# self.assertTrue(isinstance(mono_args, list)) -# self.assertEqual(len(mono_args), 1) -# -# multi_data = {"x": ["This is a sample input", "This is a second sample input"]} -# multi_args = self.handler(**multi_data) -# -# self.assertTrue(isinstance(multi_args, list)) -# self.assertEqual(len(multi_args), 2) -# -# def test_kwargs_data(self): -# mono_data = {"data": "This is a sample input"} -# mono_args = self.handler(**mono_data) -# -# self.assertTrue(isinstance(mono_args, list)) -# self.assertEqual(len(mono_args), 1) -# -# multi_data = {"data": ["This is a sample input", "This is a second sample input"]} -# multi_args = self.handler(**multi_data) -# -# self.assertTrue(isinstance(multi_args, list)) -# self.assertEqual(len(multi_args), 2) -# -# def test_multi_kwargs(self): -# mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"} -# mono_args = self.handler(**mono_data) -# -# self.assertTrue(isinstance(mono_args, list)) -# self.assertEqual(len(mono_args), 2) -# -# multi_data = { -# "data": ["This is a sample input", "This is a second sample input"], -# "test": ["This is a sample input 2", "This is a second sample input 2"], -# } -# multi_args = self.handler(**multi_data) -# -# self.assertTrue(isinstance(multi_args, list)) -# self.assertEqual(len(multi_args), 4) -# -# def test_args(self): -# mono_data = "This is a sample input" -# mono_args = self.handler(mono_data) -# -# self.assertTrue(isinstance(mono_args, list)) -# self.assertEqual(len(mono_args), 1) -# -# mono_data = ["This is a sample input"] -# mono_args = self.handler(mono_data) -# -# self.assertTrue(isinstance(mono_args, list)) -# self.assertEqual(len(mono_args), 1) -# -# multi_data = ["This is a sample input", "This is a second sample input"] -# multi_args = self.handler(multi_data) -# -# self.assertTrue(isinstance(multi_args, list)) -# self.assertEqual(len(multi_args), 2) -# -# multi_data = ["This is a sample input", "This is a second sample input"] -# multi_args = self.handler(*multi_data) -# -# self.assertTrue(isinstance(multi_args, list)) -# self.assertEqual(len(multi_args), 2) diff --git a/tests/test_pipelines_dialog.py b/tests/test_pipelines_dialog.py deleted file mode 100644 index 751d4b2b3e5f..000000000000 --- a/tests/test_pipelines_dialog.py +++ /dev/null @@ -1,29 +0,0 @@ -import unittest - -from transformers.pipelines import Conversation, Pipeline - -from .test_pipelines_common import CustomInputPipelineCommonMixin - - -class DialoguePipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): - pipeline_task = "conversational" - small_models = [] # Default model - Models tested without the @slow decorator - large_models = ["microsoft/DialoGPT-medium"] # Models tested with the @slow decorator - - def _test_pipeline(self, nlp: Pipeline): - valid_inputs = [Conversation("Hi there!"), [Conversation("Hi there!"), Conversation("How are you?")]] - invalid_inputs = ["Hi there!", Conversation()] - self.assertIsNotNone(nlp) - - mono_result = nlp(valid_inputs[0]) - self.assertIsInstance(mono_result, Conversation) - - multi_result = nlp(valid_inputs[1]) - self.assertIsInstance(multi_result, list) - self.assertIsInstance(multi_result[0], Conversation) - # Inactive conversations passed to the pipeline raise a ValueError - self.assertRaises(ValueError, nlp, valid_inputs[1]) - - for bad_input in invalid_inputs: - self.assertRaises(Exception, nlp, bad_input) - self.assertRaises(Exception, nlp, invalid_inputs) diff --git a/tests/test_pipelines_ner.py b/tests/test_pipelines_ner.py index bc12900d8422..58da4aded63e 100644 --- a/tests/test_pipelines_ner.py +++ b/tests/test_pipelines_ner.py @@ -146,10 +146,10 @@ def test_tf_small_ignore_subwords_available_for_fast_tokenizers(self): @require_torch def test_pt_ignore_subwords_slow_tokenizer_raises(self): for model_name in self.small_models: - tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) with self.assertRaises(ValueError): - pipeline(task="ner", model=model_name, tokenizer=tokenizer, ignore_subwords=True) + pipeline(task="ner", model=model_name, tokenizer=tokenizer, ignore_subwords=True, use_fast=False) @require_torch def test_pt_defaults_slow_tokenizer(self): diff --git a/tests/test_pipelines_question_answering.py b/tests/test_pipelines_question_answering.py index 54b306c09d88..9b25c57342d8 100644 --- a/tests/test_pipelines_question_answering.py +++ b/tests/test_pipelines_question_answering.py @@ -8,10 +8,22 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): pipeline_task = "question-answering" + pipeline_running_kwargs = { + "padding": "max_length", + "max_seq_len": 25, + "doc_stride": 5, + } # Default is 'longest' but we use 'max_length' to test equivalence between slow/fast tokenizers small_models = [ "sshleifer/tiny-distilbert-base-cased-distilled-squad" ] # Models tested without the @slow decorator large_models = [] # Models tested with the @slow decorator + valid_inputs = [ + {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}, + { + "question": "In what field is HuggingFace working ?", + "context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.", + }, + ] def _test_pipeline(self, nlp: Pipeline): output_keys = {"score", "answer", "start", "end"} diff --git a/tests/test_pipelines_zero_shot.py b/tests/test_pipelines_zero_shot.py index 39bc2dc124cc..ae2086d426c3 100644 --- a/tests/test_pipelines_zero_shot.py +++ b/tests/test_pipelines_zero_shot.py @@ -12,6 +12,18 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english" ] # Models tested without the @slow decorator large_models = ["roberta-large-mnli"] # Models tested with the @slow decorator + valid_inputs = [ + {"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"}, + {"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics"]}, + {"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics, public health"}, + {"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics", "public health"]}, + {"sequences": ["Who are you voting for in 2020?"], "candidate_labels": "politics"}, + { + "sequences": "Who are you voting for in 2020?", + "candidate_labels": "politics", + "hypothesis_template": "This text is about {}", + }, + ] def _test_scores_sum_to_one(self, result): sum = 0.0 diff --git a/tests/test_retrieval_rag.py b/tests/test_retrieval_rag.py index 93774be18382..a95324535b82 100644 --- a/tests/test_retrieval_rag.py +++ b/tests/test_retrieval_rag.py @@ -9,7 +9,7 @@ import numpy as np from datasets import Dataset -import faiss +from transformers import is_faiss_available from transformers.configuration_bart import BartConfig from transformers.configuration_dpr import DPRConfig from transformers.configuration_rag import RagConfig @@ -27,6 +27,10 @@ from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES +if is_faiss_available(): + import faiss + + @require_faiss @require_datasets class RagRetrieverTest(TestCase): diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py index 390e89b08939..e06d7800bb1d 100644 --- a/tests/test_tokenization_auto.py +++ b/tests/test_tokenization_auto.py @@ -116,5 +116,5 @@ def test_parents_and_children_in_mappings(self): @require_tokenizers def test_from_pretrained_use_fast_toggle(self): - self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased"), BertTokenizer) - self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True), BertTokenizerFast) + self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer) + self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased"), BertTokenizerFast) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 0090c0f47d30..376616a0b5de 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -576,6 +576,42 @@ def test_mask_output(self): sequences, mask = information["input_ids"], information["token_type_ids"] self.assertEqual(len(sequences), len(mask)) + def test_token_type_ids(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + seq_0 = "Test this method." + + # We want to have sequence 0 and sequence 1 are tagged + # respectively with 0 and 1 token_ids + # (regardeless of weither the model use token type ids) + # We use this assumption in the QA pipeline among other place + output = tokenizer(seq_0, return_token_type_ids=True) + self.assertIn(0, output["token_type_ids"]) + + def test_sequence_ids(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + if not tokenizer.is_fast: + continue + with self.subTest(f"{tokenizer.__class__.__name__}"): + seq_0 = "Test this method." + seq_1 = "With these inputs." + + # We want to have sequence 0 and sequence 1 are tagged + # respectively with 0 and 1 token_ids + # (regardeless of weither the model use token type ids) + # We use this assumption in the QA pipeline among other place + output = tokenizer(seq_0) + self.assertIn(0, output.sequence_ids()) + + output = tokenizer(seq_0, seq_1) + self.assertIn(0, output.sequence_ids()) + self.assertIn(1, output.sequence_ids()) + + if tokenizer.num_special_tokens_to_add(pair=True): + self.assertIn(None, output.sequence_ids()) + def test_number_of_added_tokens(self): tokenizers = self.get_tokenizers(do_lower_case=False) for tokenizer in tokenizers: @@ -1878,6 +1914,144 @@ def test_alignement_methods(self): batch_encoding.word_to_chars(last_batch_index, last_word_index).end, last_char_index + 1 ) + # Assert token_to_sequence + self.assertEqual(encoding.token_to_sequence(num_tokens // 2), 0) + self.assertEqual(encoding.token_to_sequence(0, num_tokens // 2), 0) + self.assertEqual(batch_encoding.token_to_sequence(1, num_tokens // 2), 0) + self.assertEqual(batch_encoding.token_to_sequence(0, num_tokens // 2), 0) + self.assertEqual(batch_encoding.token_to_sequence(last_batch_index, num_tokens // 2), 0) + + # Pair of input sequences + + words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"] + text = " ".join(words) + pair_words = ["Amazing", "example", "full", "of", "inspiration"] + pair_text = " ".join(pair_words) + batch_size = 3 + index_word_in_first_seq = words.index("inspiration") + index_word_in_pair_seq = pair_words.index("inspiration") + index_char_in_first_seq = text.find("inspiration") + index_char_in_pair_seq = pair_text.find("inspiration") + + pair_encoding = tokenizer_r.encode_plus(text, pair_text, add_special_tokens=False) + + pair_batch_encoding = tokenizer_r.batch_encode_plus( + [(text, pair_text)] * batch_size, add_special_tokens=False + ) + num_tokens = len(encoding["input_ids"]) + + last_word_index = len(words) - 1 + last_token_index = num_tokens - 1 + last_batch_index = batch_size - 1 + last_char_index = len(text) - 1 + + # Assert word_to_tokens + self.assertNotEqual( + pair_encoding.word_to_tokens(index_word_in_first_seq, sequence_index=0).start, + pair_encoding.word_to_tokens(index_word_in_pair_seq, sequence_index=1).start, + ) + self.assertEqual( + pair_encoding["input_ids"][ + pair_encoding.word_to_tokens(index_word_in_first_seq, sequence_index=0).start + ], + pair_encoding["input_ids"][ + pair_encoding.word_to_tokens(index_word_in_pair_seq, sequence_index=1).start + ], + ) + self.assertNotEqual( + pair_batch_encoding.word_to_tokens(1, index_word_in_first_seq, sequence_index=0).start, + pair_batch_encoding.word_to_tokens(1, index_word_in_pair_seq, sequence_index=1).start, + ) + self.assertEqual( + pair_batch_encoding["input_ids"][1][ + pair_batch_encoding.word_to_tokens(1, index_word_in_first_seq, sequence_index=0).start + ], + pair_batch_encoding["input_ids"][1][ + pair_batch_encoding.word_to_tokens(1, index_word_in_pair_seq, sequence_index=1).start + ], + ) + + # Assert char_to_token + self.assertNotEqual( + pair_encoding.char_to_token(index_char_in_first_seq, sequence_index=0), + pair_encoding.char_to_token(index_char_in_pair_seq, sequence_index=1), + ) + self.assertEqual( + pair_encoding["input_ids"][pair_encoding.char_to_token(index_char_in_first_seq, sequence_index=0)], + pair_encoding["input_ids"][pair_encoding.char_to_token(index_char_in_pair_seq, sequence_index=1)], + ) + self.assertNotEqual( + pair_batch_encoding.char_to_token(1, index_char_in_first_seq, sequence_index=0), + pair_batch_encoding.char_to_token(1, index_char_in_pair_seq, sequence_index=1), + ) + self.assertEqual( + pair_batch_encoding["input_ids"][1][ + pair_batch_encoding.char_to_token(1, index_char_in_first_seq, sequence_index=0) + ], + pair_batch_encoding["input_ids"][1][ + pair_batch_encoding.char_to_token(1, index_char_in_pair_seq, sequence_index=1) + ], + ) + + # Assert char_to_word + self.assertNotEqual( + pair_encoding.char_to_word(index_char_in_first_seq, sequence_index=0), + pair_encoding.char_to_word(index_char_in_pair_seq, sequence_index=1), + ) + self.assertEqual( + words[pair_encoding.char_to_word(index_char_in_first_seq, sequence_index=0)], + pair_words[pair_encoding.char_to_word(index_char_in_pair_seq, sequence_index=1)], + ) + self.assertNotEqual( + pair_batch_encoding.char_to_word(1, index_char_in_first_seq, sequence_index=0), + pair_batch_encoding.char_to_word(1, index_char_in_pair_seq, sequence_index=1), + ) + self.assertEqual( + words[pair_batch_encoding.char_to_word(1, index_char_in_first_seq, sequence_index=0)], + pair_words[pair_batch_encoding.char_to_word(1, index_char_in_pair_seq, sequence_index=1)], + ) + + # Assert word_to_chars + self.assertNotEqual( + pair_encoding.word_to_chars(index_word_in_first_seq, sequence_index=0).start, + pair_encoding.word_to_chars(index_word_in_pair_seq, sequence_index=1).start, + ) + self.assertEqual( + text[pair_encoding.word_to_chars(index_word_in_first_seq, sequence_index=0).start], + pair_text[pair_encoding.word_to_chars(index_word_in_pair_seq, sequence_index=1).start], + ) + self.assertNotEqual( + pair_batch_encoding.word_to_chars(1, index_word_in_first_seq, sequence_index=0).start, + pair_batch_encoding.word_to_chars(1, index_word_in_pair_seq, sequence_index=1).start, + ) + self.assertEqual( + text[pair_batch_encoding.word_to_chars(1, index_word_in_first_seq, sequence_index=0).start], + pair_text[pair_batch_encoding.word_to_chars(1, index_word_in_pair_seq, sequence_index=1).start], + ) + + # Assert token_to_sequence + pair_encoding = tokenizer_r.encode_plus(text, pair_text, add_special_tokens=True) + + pair_sequence_ids = [ + pair_encoding.token_to_sequence(i) for i in range(len(pair_encoding["input_ids"])) + ] + self.assertIn(0, pair_sequence_ids) + self.assertIn(1, pair_sequence_ids) + if tokenizer_r.num_special_tokens_to_add(pair=True): + self.assertIn(None, pair_sequence_ids) + + pair_batch_encoding = tokenizer_r.batch_encode_plus( + [(text, pair_text)] * batch_size, add_special_tokens=True + ) + pair_batch_sequence_ids = [ + pair_batch_encoding.token_to_sequence(1, i) + for i in range(len(pair_batch_encoding["input_ids"][0])) + ] + self.assertIn(0, pair_batch_sequence_ids) + self.assertIn(1, pair_batch_sequence_ids) + if tokenizer_r.num_special_tokens_to_add(pair=True): + self.assertIn(None, pair_batch_sequence_ids) + def test_tokenization_python_rust_equals(self): for tokenizer, pretrained_name, kwargs in self.tokenizers_list: with self.subTest("{} ({})".format(tokenizer.__class__.__name__, pretrained_name)): diff --git a/tests/test_tokenization_rag.py b/tests/test_tokenization_rag.py index 158aadca6940..63bdb541e61d 100644 --- a/tests/test_tokenization_rag.py +++ b/tests/test_tokenization_rag.py @@ -4,13 +4,12 @@ import tempfile from unittest import TestCase +from transformers import BartTokenizer, BartTokenizerFast, DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast from transformers.configuration_bart import BartConfig from transformers.configuration_dpr import DPRConfig from transformers.file_utils import is_datasets_available, is_faiss_available, is_torch_available -from transformers.testing_utils import require_datasets, require_faiss, require_torch, slow -from transformers.tokenization_bart import BartTokenizer +from transformers.testing_utils import require_datasets, require_faiss, require_tokenizers, require_torch, slow from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES -from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES @@ -96,6 +95,7 @@ def get_bart_tokenizer(self) -> BartTokenizer: def tearDown(self): shutil.rmtree(self.tmpdirname) + @require_tokenizers def test_save_load_pretrained_with_saved_config(self): save_dir = os.path.join(self.tmpdirname, "rag_tokenizer") @@ -104,10 +104,10 @@ def test_save_load_pretrained_with_saved_config(self): rag_config.save_pretrained(save_dir) rag_tokenizer.save_pretrained(save_dir) new_rag_tokenizer = RagTokenizer.from_pretrained(save_dir, config=rag_config) - self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizer) - self.assertEqual(new_rag_tokenizer.question_encoder.vocab, rag_tokenizer.question_encoder.vocab) - self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizer) - self.assertEqual(new_rag_tokenizer.generator.encoder, rag_tokenizer.generator.encoder) + self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizerFast) + self.assertEqual(new_rag_tokenizer.question_encoder.get_vocab(), rag_tokenizer.question_encoder.get_vocab()) + self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizerFast) + self.assertEqual(new_rag_tokenizer.generator.get_vocab(), rag_tokenizer.generator.get_vocab()) @slow def test_pretrained_token_nq_tokenizer(self): diff --git a/tests/test_tokenization_xlm_prophetnet.py b/tests/test_tokenization_xlm_prophetnet.py index 83097ff71d71..7dfdee6b5f8a 100644 --- a/tests/test_tokenization_xlm_prophetnet.py +++ b/tests/test_tokenization_xlm_prophetnet.py @@ -18,7 +18,7 @@ import unittest from transformers.file_utils import cached_property -from transformers.testing_utils import slow +from transformers.testing_utils import require_sentencepiece, slow from transformers.tokenization_xlm_prophetnet import SPIECE_UNDERLINE, XLMProphetNetTokenizer from .test_tokenization_common import TokenizerTesterMixin @@ -27,6 +27,7 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") +@require_sentencepiece class XLMProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = XLMProphetNetTokenizer