Merged
Commits
26 commits
4bc3b3f
Fixing roberta for slow-fast tests
thomwolf Oct 26, 2020
1ce1c63
WIP getting equivalence on pipelines
thomwolf Oct 26, 2020
15350e8
slow-to-fast equivalence - working on question-answering pipeline
thomwolf Oct 27, 2020
449e346
optional FAISS tests
thomwolf Nov 2, 2020
eb375bc
Pipeline Q&A
thomwolf Nov 9, 2020
a4cb7f6
Merge branch 'master' into slow-fast-comparison-pipelines
thomwolf Nov 9, 2020
3367593
Move pipeline tests to their own test job again
thomwolf Nov 9, 2020
36e0900
update tokenizer to add sequence id methods
thomwolf Nov 9, 2020
ef4919b
update to tokenizers 0.9.4
thomwolf Nov 10, 2020
dab8168
set sentencepiece as optional
thomwolf Nov 10, 2020
9e72b29
Merge branch 'master' into slow-fast-comparison-pipelines
thomwolf Nov 10, 2020
84bc244
clean up squad
thomwolf Nov 10, 2020
751ee69
clean up pipelines to use sequence_ids
thomwolf Nov 10, 2020
0e8d7f7
style/quality
thomwolf Nov 10, 2020
eb72b1f
wording
thomwolf Nov 10, 2020
16da2c5
Switch to use_fast = True by default
thomwolf Nov 10, 2020
0f03fdb
update tests for use_fast at True by default
thomwolf Nov 10, 2020
87cb801
fix rag tokenizer test
thomwolf Nov 10, 2020
77ee69f
removing protobuf from required dependencies
thomwolf Nov 10, 2020
1483927
fix NER test for use_fast = True by default
thomwolf Nov 10, 2020
b115646
fixing example tests (Q&A examples use slow tokenizers for now)
thomwolf Nov 10, 2020
56f77e8
protobuf in main deps extras["sentencepiece"] and example deps
thomwolf Nov 10, 2020
6894fc0
fix protobuf install test
thomwolf Nov 10, 2020
2441d40
try to fix seq2seq by switching to slow tokenizers for now
thomwolf Nov 10, 2020
fc2daad
Update src/transformers/tokenization_utils_base.py
thomwolf Nov 10, 2020
002848b
Update src/transformers/tokenization_utils_base.py
thomwolf Nov 10, 2020
6 changes: 3 additions & 3 deletions setup.py
@@ -96,7 +96,7 @@
extras["retrieval"] = ["faiss-cpu", "datasets"]
extras["flax"] = ["jaxlib==0.1.55", "jax>=0.2.0", "flax==0.2.2"]

extras["tokenizers"] = ["tokenizers==0.9.2"]
extras["tokenizers"] = ["tokenizers==0.9.4"]
extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"]

extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
@@ -129,7 +129,7 @@
packages=find_packages("src"),
install_requires=[
"numpy",
"tokenizers == 0.9.3",
"tokenizers == 0.9.4",
# dataclasses for Python versions that don't have it
"dataclasses;python_version<'3.7'",
# utilities from PyPA to e.g. compare versions
@@ -143,7 +143,7 @@
# for OpenAI GPT
"regex != 2019.12.17",
# for SentencePiece models
"sentencepiece == 0.1.91",
# "sentencepiece == 0.1.91",
"protobuf",
# for XLM
"sacremoses",
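The hunk above bumps the pinned tokenizers version to 0.9.4 and turns sentencepiece into an optional dependency (per the commit list it moves into the extras["sentencepiece"] group together with protobuf). A minimal sketch of how downstream code can cope with the now-optional package; the HAS_SENTENCEPIECE flag is illustrative and not part of this diff.

```python
# Minimal sketch: sentencepiece is no longer a hard requirement, so code that
# needs it should probe for the package instead of importing it unconditionally.
try:
    import sentencepiece  # noqa: F401
    HAS_SENTENCEPIECE = True
except ImportError:
    HAS_SENTENCEPIECE = False

if not HAS_SENTENCEPIECE:
    # Point users at the optional extra rather than failing with a bare ImportError.
    print('sentencepiece is not installed; try: pip install "transformers[sentencepiece]"')
```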
6 changes: 5 additions & 1 deletion src/transformers/data/processors/squad.py
@@ -8,7 +8,7 @@

from ...file_utils import is_tf_available, is_torch_available
from ...tokenization_bert import whitespace_tokenize
from ...tokenization_utils_base import PreTrainedTokenizerBase, TruncationStrategy
from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy
from ...utils import logging
from .utils import DataProcessor

@@ -765,6 +765,7 @@ class SquadFeatures:
token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer.
start_position: start of the answer token index
end_position: end of the answer token index
encoding: optionally store the BatchEncoding with the fast-tokenizer alignment methods.
"""

def __init__(
@@ -784,6 +785,7 @@ def __init__(
end_position,
is_impossible,
qas_id: str = None,
encoding: BatchEncoding = None,
):
self.input_ids = input_ids
self.attention_mask = attention_mask
@@ -803,6 +805,8 @@ def __init__(
self.is_impossible = is_impossible
self.qas_id = qas_id

self.encoding = encoding


class SquadResult:
"""
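The new optional encoding field keeps the fast tokenizer's BatchEncoding on the feature, so its alignment methods remain available to downstream consumers such as the question-answering pipeline. A hedged sketch of those methods, using the same calls the pipeline relies on; the checkpoint and token index are chosen purely for illustration.

```python
# Sketch, assuming a fast tokenizer: the stored encoding can map a token back to
# its word and to a character span in the original context.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
context = "It was written by Jane Austen."
enc = tokenizer("Who wrote it?", context)

span = enc[0]                    # underlying tokenizers.Encoding for this example
token_index = 8                  # a token inside the context, picked for illustration
word_index = span.token_to_word(token_index)                             # token -> word
char_start, char_end = span.word_to_chars(word_index, sequence_index=1)  # word -> chars
print(context[char_start:char_end])
```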
165 changes: 135 additions & 30 deletions src/transformers/pipelines.py
@@ -32,7 +32,7 @@

from .configuration_auto import AutoConfig
from .configuration_utils import PretrainedConfig
from .data import SquadExample, squad_convert_examples_to_features
from .data import SquadExample, SquadFeatures, squad_convert_examples_to_features
from .file_utils import add_end_docstrings, is_tf_available, is_torch_available
from .modelcard import ModelCard
from .tokenization_auto import AutoTokenizer
@@ -1758,6 +1758,7 @@ def __call__(self, *args, **kwargs):
- **answer** (:obj:`str`) -- The answer to the question.
"""
# Set defaults values
kwargs.setdefault("padding", "longest")
kwargs.setdefault("topk", 1)
kwargs.setdefault("doc_stride", 128)
kwargs.setdefault("max_answer_len", 15)
@@ -1773,19 +1774,87 @@ def __call__(self, *args, **kwargs):

# Convert inputs to features
examples = self._args_parser(*args, **kwargs)
features_list = [
squad_convert_examples_to_features(
examples=[example],
tokenizer=self.tokenizer,
max_seq_length=kwargs["max_seq_len"],
doc_stride=kwargs["doc_stride"],
max_query_length=kwargs["max_question_len"],
padding_strategy=PaddingStrategy.MAX_LENGTH.value,
is_training=False,
tqdm_enabled=False,
)
for example in examples
]
if not self.tokenizer.is_fast:
features_list = [
squad_convert_examples_to_features(
examples=[example],
tokenizer=self.tokenizer,
max_seq_length=kwargs["max_seq_len"],
doc_stride=kwargs["doc_stride"],
max_query_length=kwargs["max_question_len"],
padding_strategy=PaddingStrategy.MAX_LENGTH.value,
is_training=False,
tqdm_enabled=False,
)
for example in examples
]
else:
features_list = []
for example in examples:
# Define the side we want to truncate / pad and the text/pair ordering
question_first = bool(self.tokenizer.padding_side == "right")

encoded_inputs = self.tokenizer(
text=example.question_text if question_first else example.context_text,
text_pair=example.context_text if question_first else example.question_text,
padding=kwargs["padding"],
truncation="only_second" if question_first else "only_first",
max_length=kwargs["max_seq_len"],
stride=kwargs["doc_stride"],
return_tensors="np",
return_token_type_ids=True,
return_overflowing_tokens=True,
return_offsets_mapping=True,
return_special_tokens_mask=True,
)

# When the input is too long, it is converted into a batch of inputs with overflowing tokens
# and a stride of overlap between the inputs. If a batch of inputs is given, a special output
# "overflow_to_sample_mapping" indicates which member of the encoded batch belongs to which original sample.
# Here we tokenize examples one-by-one so we don't need to use "overflow_to_sample_mapping".
# "num_spans" is the number of output samples generated from the overflowing tokens.
num_spans = len(encoded_inputs["input_ids"])

# p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
# We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
p_mask = np.asarray(
[
[tok != 1 if question_first else 0 for tok in encoded_inputs.sequence_ids(span_id)]
for span_id in range(num_spans)
]
)

# keep the cls_token unmasked (some models use it to indicate unanswerable questions)
if self.tokenizer.cls_token_id:
cls_index = np.nonzero(encoded_inputs["input_ids"] == self.tokenizer.cls_token_id)
p_mask[cls_index] = 0

features = []
for span_idx in range(num_spans):
features.append(
SquadFeatures(
input_ids=encoded_inputs["input_ids"][span_idx],
attention_mask=encoded_inputs["attention_mask"][span_idx],
token_type_ids=encoded_inputs["token_type_ids"][span_idx],
p_mask=p_mask[span_idx].tolist(),
encoding=encoded_inputs[span_idx],
# We don't use the rest of the values - and actually
# for Fast tokenizer we could totally avoid using SquadFeatures and SquadExample
Review comment (Collaborator): That would be nice, as a longer-term goal :-)

Review comment (Member): That would be nice, indeed! If we plan on keeping the SquadFeatures in the library, we should probably put them in the docs and give them better docstrings.

Review comment (Member Author): No, we can remove them. I didn't spend too much time making the pipeline pretty since there is a big redesign coming soon which will likely get rid of the slow tokenizers.
cls_index=None,
token_to_orig_map={},
example_index=0,
unique_id=0,
paragraph_len=0,
token_is_max_context=0,
tokens=[],
start_position=0,
end_position=0,
is_impossible=False,
qas_id=None,
)
)
features_list.append(features)

all_answers = []
for features, example in zip(features_list, examples):
model_input_names = self.tokenizer.model_input_names + ["input_ids"]
@@ -1828,20 +1897,56 @@ def __call__(self, *args, **kwargs):
start_[0] = end_[0] = 0.0

starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
char_to_word = np.array(example.char_to_word_offset)

# Convert the answer (tokens) back to the original text
answers += [
{
"score": score.item(),
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
"answer": " ".join(
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
),
}
for s, e, score in zip(starts, ends, scores)
]
if not self.tokenizer.is_fast:
char_to_word = np.array(example.char_to_word_offset)

# Convert the answer (tokens) back to the original text
# Score: score from the model
# Start: Index of the first character of the answer in the context string
# End: Index of the character following the last character of the answer in the context string
# Answer: Plain text of the answer
answers += [
{
"score": score.item(),
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
"answer": " ".join(
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
),
}
for s, e, score in zip(starts, ends, scores)
]
else:
# Convert the answer (tokens) back to the original text
# Score: score from the model
# Start: Index of the first character of the answer in the context string
# End: Index of the character following the last character of the answer in the context string
# Answer: Plain text of the answer
question_first = bool(self.tokenizer.padding_side == "right")
enc = feature.encoding

# Sometimes the max probability token is in the middle of a word so:
# - we start by finding the right word containing the token with `token_to_word`
# - then we convert this word in a character span with `word_to_chars`
answers += [
{
"score": score.item(),
"start": enc.word_to_chars(
enc.token_to_word(s), sequence_index=1 if question_first else 0
)[0],
"end": enc.word_to_chars(enc.token_to_word(e), sequence_index=1 if question_first else 0)[
1
],
"answer": example.context_text[
enc.word_to_chars(enc.token_to_word(s), sequence_index=1 if question_first else 0)[
0
] : enc.word_to_chars(enc.token_to_word(e), sequence_index=1 if question_first else 0)[
1
]
],
}
for s, e, score in zip(starts, ends, scores)
]

if kwargs["handle_impossible_answer"]:
answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})
@@ -2735,7 +2840,7 @@ def pipeline(
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
framework: Optional[str] = None,
revision: Optional[str] = None,
use_fast: bool = False,
use_fast: bool = True,
**kwargs
) -> Pipeline:
"""
@@ -2793,7 +2898,7 @@
When passing a task name or a string model identifier: The specific model version to use. It can be a
branch name, a tag name, or a commit id, since we use a git-based system for storing models and other
artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`).
kwargs:
Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
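Taken together, the fast-tokenizer branch above replaces squad_convert_examples_to_features with a direct tokenizer call (long contexts become several overlapping spans controlled by doc_stride), derives p_mask from sequence_ids, and converts the predicted token span back to character offsets through the stored encoding. Below is a condensed, hedged sketch of that flow outside the pipeline; the checkpoint and the start/end token indices are illustrative rather than taken from a real model run.

```python
# Sketch of the fast-tokenizer question-answering path used above.
# Assumptions: the question is passed first (padding_side == "right") and
# s/e are already-decoded start/end token indices for span 0.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad", use_fast=True)
question = "Who wrote the book?"
context = "The book was written by Jane Austen in 1811."

encoded = tokenizer(
    question,
    context,
    truncation="only_second",      # only the context may be truncated
    max_length=384,
    stride=128,                    # doc_stride: overlap between overflowing spans
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

# sequence_ids marks question tokens (0), context tokens (1) and special tokens (None);
# the pipeline derives p_mask from this so only context tokens can be part of the answer.
print(encoded.sequence_ids(0))

s, e = 10, 12                      # pretend the model picked these token indices
span = encoded[0]                  # underlying tokenizers.Encoding for span 0
start_char = span.word_to_chars(span.token_to_word(s), sequence_index=1)[0]
end_char = span.word_to_chars(span.token_to_word(e), sequence_index=1)[1]
print(context[start_char:end_char])
```

Since pipeline() now defaults to use_fast=True, the question-answering pipeline takes this branch unless a slow tokenizer is passed in explicitly.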
4 changes: 2 additions & 2 deletions src/transformers/tokenization_auto.py
@@ -280,7 +280,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to try to load the fast version of the tokenizer.
kwargs (additional keyword arguments, `optional`):
Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
@@ -308,7 +308,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
if "bert-base-japanese" in str(pretrained_model_name_or_path):
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

use_fast = kwargs.pop("use_fast", False)
use_fast = kwargs.pop("use_fast", True)

if config.tokenizer_class is not None:
if use_fast and not config.tokenizer_class.endswith("Fast"):
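With the hunk above, AutoTokenizer.from_pretrained now returns a fast (Rust-backed) tokenizer whenever one exists for the model; passing use_fast=False restores the previous behaviour. A short usage sketch:

```python
from transformers import AutoTokenizer

# New default: a PreTrainedTokenizerFast is returned when the model has one.
fast_tok = AutoTokenizer.from_pretrained("bert-base-uncased")
assert fast_tok.is_fast

# Opting back into the Python ("slow") tokenizer, as this PR still does for the
# seq2seq and Q&A example scripts.
slow_tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)
assert not slow_tok.is_fast
```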
27 changes: 27 additions & 0 deletions src/transformers/tokenization_roberta_fast.py
@@ -18,6 +18,7 @@

from .tokenization_gpt2_fast import GPT2TokenizerFast
from .tokenization_roberta import RobertaTokenizer
from .tokenization_utils_base import AddedToken
from .utils import logging


@@ -172,6 +173,32 @@ def __init__(
**kwargs,
)

@property
def mask_token(self) -> str:
"""
:obj:`str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while
not having been set.

Roberta tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
comprise the space before the `<mask>`.
"""
if self._mask_token is None and self.verbose:
logger.error("Using mask_token, but it is not set yet.")
Comment on lines +185 to +186
Review comment (Member): Do we need the self.verbose now that the logger is centralized and can be easily managed?
Review comment (Member Author): No indeed, let me remove that in a follow-up PR.

return None
return str(self._mask_token)

@mask_token.setter
def mask_token(self, value):
"""
Overriding the default behavior of the mask token to have it eat the space before it.

This is needed to preserve backward compatibility with all the previously used models based on Roberta.
"""
# Mask token behaves like a normal word, i.e. it includes the space before it
# So we set lstrip to True
value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
self._mask_token = value

def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
if token_ids_1 is None:
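The overridden mask_token setter wraps the value in an AddedToken with lstrip=True, so the mask token absorbs the space in front of it and the fill-mask behaviour of previously released RoBERTa checkpoints is preserved. A minimal sketch of the intended effect; the printed token sequence is illustrative and depends on the checkpoint:

```python
# Sketch: with lstrip=True on <mask>, the preceding space is folded into the mask
# token instead of producing a stray whitespace token before it.
from transformers import RobertaTokenizerFast

tok = RobertaTokenizerFast.from_pretrained("roberta-base")
ids = tok("Hello <mask>!")["input_ids"]
print(tok.convert_ids_to_tokens(ids))
# Roughly expected (illustrative): ['<s>', 'Hello', '<mask>', '!', '</s>']
```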