From 74b3d7abce96c79bf8c35517857b4032b3d85a21 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 4 Nov 2020 15:08:05 -0500 Subject: [PATCH 1/3] Clean up data collators and datasets --- examples/language-modeling/run_mlm.py | 14 +- src/transformers/__init__.py | 1 - src/transformers/data/data_collator.py | 266 ++++++------------ .../data/datasets/language_modeling.py | 26 +- src/transformers/utils/dummy_pt_objects.py | 5 - tests/test_data_collator.py | 19 +- 6 files changed, 135 insertions(+), 196 deletions(-) diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py index d6653b18057d..cd1cc3f26da7 100644 --- a/examples/language-modeling/run_mlm.py +++ b/examples/language-modeling/run_mlm.py @@ -264,7 +264,15 @@ def main(): def tokenize_function(examples): # Remove empty lines examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()] - return tokenizer(examples["text"], padding=padding, truncation=True, max_length=data_args.max_seq_length) + return tokenizer( + examples["text"], + padding=padding, + truncation=True, + max_length=data_args.max_seq_length, + # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it + # receives the `special_tokens_mask`. + return_special_tokens_mask=True, + ) tokenized_datasets = datasets.map( tokenize_function, @@ -275,8 +283,10 @@ def tokenize_function(examples): ) else: # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. + # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more + # efficient when it receives the `special_tokens_mask`. def tokenize_function(examples): - return tokenizer(examples[text_column_name]) + return tokenizer(examples[text_column_name], return_special_tokens_mask=True) tokenized_datasets = datasets.map( tokenize_function, diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ee5da4399984..d8279e8c5305 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -281,7 +281,6 @@ from .data.data_collator import ( DataCollator, DataCollatorForLanguageModeling, - DataCollatorForNextSentencePrediction, DataCollatorForPermutationLanguageModeling, DataCollatorForSOP, DataCollatorForTokenClassification, diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 7eccad31f9ff..a634852252da 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -1,4 +1,5 @@ import random +import warnings from dataclasses import dataclass from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union @@ -175,72 +176,111 @@ def __call__(self, features): return batch +def _collate_batch(examples, tokenizer): + """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" + # Tensorize if necessary. + if isinstance(examples[0], (list, tuple)): + examples = [torch.tensor(e, dtype=torch.long) for e in examples] + + # Check if padding is necessary. + length_of_first = examples[0].size(0) + are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) + if are_tensors_same_length: + return torch.stack(examples, dim=0) + + # If yes, check if we have a `pad_token`. + if tokenizer._pad_token is None: + raise ValueError( + "You are attempting to pad samples but the tokenizer you are using" + f" ({tokenizer.__class__.__name__}) does not have one." 
+ ) + + # Creating the full tensor and filling it with our data. + max_length = max(x.size(0) for x in examples) + result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id) + for i, example in enumerate(examples): + if tokenizer.padding_side == "right": + result[i, : example.shape[0]] = example + else: + result[i, -example.shape[0] :] = example + return result + + @dataclass class DataCollatorForLanguageModeling: """ - Data collator used for language modeling. + Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they + are not all of the same length. - - collates batches of tensors, honoring their tokenizer's pad_token - - preprocesses batches for masked language modeling + Args: + tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): + The tokenizer used for encoding the data. + mlm (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the + inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for + non-masked tokens and the value to predict for the masked token. + mlm_probability (:obj:`float`, `optional`, defaults to 0.15): + The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`. + + .. note:: + + For best performance, this data collator should be used with a dataset having items that are dictionaries or + BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a + :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the + argument :obj:`return_special_tokens_mask=True`. """ tokenizer: PreTrainedTokenizerBase mlm: bool = True mlm_probability: float = 0.15 + def __post_init__(self): + if self.mlm and self.tokenizer.mask_token is None: + raise ValueError( + "This tokenizer does not have a mask token which is necessary for masked language modeling. " + "You should pass `mlm=False`." + ) + def __call__( self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] ) -> Dict[str, torch.Tensor]: + # Handle dict or lists with proper padding and conversion to tensor. if isinstance(examples[0], (dict, BatchEncoding)): - examples = [e["input_ids"] for e in examples] - batch = self._tensorize_batch(examples) + batch = self.tokenizer.pad(examples, return_tensors="pt") + else: + batch = {"input_ids": _collate_batch(examples, self.tokenizer)} + + # If special token mask has been preprocessed, pop it from the dict. 
+ special_tokens_mask = batch.pop("special_tokens_mask", None) if self.mlm: - inputs, labels = self.mask_tokens(batch) - return {"input_ids": inputs, "labels": labels} + batch["input_ids"], batch["labels"] = self.mask_tokens( + batch["input_ids"], special_tokens_mask=special_tokens_mask + ) else: - labels = batch.clone().detach() + labels = batch["input_ids"].clone().detach() if self.tokenizer.pad_token_id is not None: labels[labels == self.tokenizer.pad_token_id] = -100 - return {"input_ids": batch, "labels": labels} - - def _tensorize_batch( - self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] - ) -> torch.Tensor: - # In order to accept both lists of lists and lists of Tensors - if isinstance(examples[0], (list, tuple)): - examples = [torch.tensor(e, dtype=torch.long) for e in examples] - length_of_first = examples[0].size(0) - are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) - if are_tensors_same_length: - return torch.stack(examples, dim=0) - else: - if self.tokenizer._pad_token is None: - raise ValueError( - "You are attempting to pad samples but the tokenizer you are using" - f" ({self.tokenizer.__class__.__name__}) does not have one." - ) - return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id) + batch["labels"] = labels + return batch - def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def mask_tokens( + self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """ - - if self.tokenizer.mask_token is None: - raise ValueError( - "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer." 
- ) - labels = inputs.clone() - # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) + # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) probability_matrix = torch.full(labels.shape, self.mlm_probability) - special_tokens_mask = [ - self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() - ] - probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) - if self.tokenizer._pad_token is not None: - padding_mask = labels.eq(self.tokenizer.pad_token_id) - probability_matrix.masked_fill_(padding_mask, value=0.0) + if special_tokens_mask is None: + special_tokens_mask = [ + self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() + ] + special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) + else: + special_tokens_mask = special_tokens_mask.bool() + + probability_matrix.masked_fill_(special_tokens_mask, value=0.0) masked_indices = torch.bernoulli(probability_matrix).bool() labels[~masked_indices] = -100 # We only compute loss on masked tokens @@ -385,9 +425,16 @@ class DataCollatorForSOP(DataCollatorForLanguageModeling): - preprocesses batches for both masked language modeling and sentence order prediction """ + def __init__(self, *args, **kwargs): + warnings.warn( + "DataCollatorForSOP is deprecated and will be removed in a future version, you can now use " + "DataCollatorForLanguageModeling instead.", + FutureWarning, + ) + def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: input_ids = [example["input_ids"] for example in examples] - input_ids = self._tensorize_batch(input_ids) + input_ids = _collate_batch(input_ids, self.tokenizer) input_ids, labels, attention_mask = self.mask_tokens(input_ids) token_type_ids = [example["token_type_ids"] for example in examples] @@ -582,136 +629,3 @@ def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, ) & masked_indices[i] return inputs.long(), perm_mask, target_mapping, labels.long() - - -@dataclass -class DataCollatorForNextSentencePrediction: - """ - Data collator used for next sentence prediction. - collates examples which contains pre-generated negative examples - - preprocesses batches for masked language modeling - """ - - tokenizer: PreTrainedTokenizerBase - mlm: bool = True - block_size: int = 512 - short_seq_probability: float = 0.1 - nsp_probability: float = 0.5 - mlm_probability: float = 0.15 - - def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: - """ - The input should contain negative examples, :class:`~transformers.DataCollatorForNextSentencePrediction` will - not generate any negative examples - - Args: - examples (:obj:`List[Dict]`): Each dictionary should have the following keys: - - - ``tokens_a``: A sequence of tokens, which should appear before ``tokens_b`` in the text. - - ``tokens_b``: A sequence of tokens, which should appear after ``tokens_a`` in the text. - - ``is_random_next``: 1 if this pair is generated randomly, else 0. 
- """ - - tokens_a = [e["tokens_a"] for e in examples] - tokens_b = [e["tokens_b"] for e in examples] - nsp_labels = [1 if e["is_random_next"] else 0 for e in examples] - - input_ids = [] - segment_ids = [] - attention_masks = [] - - assert len(tokens_a) == len(tokens_b) - for i in range(len(tokens_a)): - input_id, attention_mask, segment_id = self.create_features_from_example(tokens_a[i], tokens_b[i]) - input_ids.append(input_id) - segment_ids.append(segment_id) - attention_masks.append(attention_mask) - if self.mlm: - input_ids, mlm_labels = self.mask_tokens(self._tensorize_batch(input_ids)) - else: - input_ids = self._tensorize_batch(input_ids) - - result = { - "input_ids": input_ids, - "attention_mask": self._tensorize_batch(attention_masks), - "token_type_ids": self._tensorize_batch(segment_ids), - "labels": mlm_labels if self.mlm else None, - "next_sentence_label": torch.tensor(nsp_labels), - } - return result - - def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor: - length_of_first = examples[0].size(0) - are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) - if are_tensors_same_length: - return torch.stack(examples, dim=0) - else: - if self.tokenizer._pad_token is None: - raise ValueError( - "You are attempting to pad samples but the tokenizer you are using" - f" ({self.tokenizer.__class__.__name__}) does not have one." - ) - return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id) - - def create_features_from_example(self, tokens_a, tokens_b): - """Creates examples for a single document.""" - - max_num_tokens = self.block_size - self.tokenizer.num_special_tokens_to_add(pair=True) - - tokens_a, tokens_b, _ = self.tokenizer.truncate_sequences( - tokens_a, - tokens_b, - num_tokens_to_remove=len(tokens_a) + len(tokens_b) - max_num_tokens, - truncation_strategy="longest_first", - ) - - input_id = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b) - attention_mask = [1] * len(input_id) - segment_id = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b) - assert len(input_id) <= self.block_size - - # pad - while len(input_id) < self.block_size: - input_id.append(0) - attention_mask.append(0) - segment_id.append(0) - - input_id = torch.tensor(input_id) - attention_mask = torch.tensor(attention_mask) - segment_id = torch.tensor(segment_id) - - return input_id, attention_mask, segment_id - - def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. - """ - - if self.tokenizer.mask_token is None: - raise ValueError( - "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer." 
- ) - - labels = inputs.clone() - # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) - probability_matrix = torch.full(labels.shape, self.mlm_probability) - special_tokens_mask = [ - self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() - ] - probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) - if self.tokenizer._pad_token is not None: - padding_mask = labels.eq(self.tokenizer.pad_token_id) - probability_matrix.masked_fill_(padding_mask, value=0.0) - masked_indices = torch.bernoulli(probability_matrix).bool() - labels[~masked_indices] = -100 # We only compute loss on masked tokens - - # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) - indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices - inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) - - # 10% of the time, we replace masked input tokens with random word - indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced - random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) - inputs[indices_random] = random_words[indices_random] - - # The rest of the time (10% of the time) we keep the masked input tokens unchanged - return inputs, labels diff --git a/src/transformers/data/datasets/language_modeling.py b/src/transformers/data/datasets/language_modeling.py index 2c26747983be..33dfbef7bf05 100644 --- a/src/transformers/data/datasets/language_modeling.py +++ b/src/transformers/data/datasets/language_modeling.py @@ -3,6 +3,7 @@ import pickle import random import time +import warnings from typing import Dict, List, Optional import torch @@ -17,6 +18,11 @@ logger = logging.get_logger(__name__) +DEPRECATION_WARNING = ( + "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library." +) + + class TextDataset(Dataset): """ This will be superseded by a framework-agnostic approach soon. 
@@ -30,6 +36,7 @@ def __init__( overwrite_cache=False, cache_dir: Optional[str] = None, ): + warnings.warn(DEPRECATION_WARNING, FutureWarning) assert os.path.isfile(file_path), f"Input file path {file_path} not found" block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False) @@ -94,6 +101,7 @@ class LineByLineTextDataset(Dataset): """ def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int): + warnings.warn(DEPRECATION_WARNING, FutureWarning) assert os.path.isfile(file_path), f"Input file path {file_path} not found" # Here, we do not cache the features, operating under the assumption # that we will soon use fast multithreaded tokenizers from the @@ -120,6 +128,7 @@ class LineByLineWithRefDataset(Dataset): """ def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str): + warnings.warn(DEPRECATION_WARNING, FutureWarning) assert os.path.isfile(file_path), f"Input file path {file_path} not found" assert os.path.isfile(ref_path), f"Ref file path {file_path} not found" # Here, we do not cache the features, operating under the assumption @@ -156,6 +165,7 @@ class LineByLineWithSOPTextDataset(Dataset): """ def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int): + warnings.warn(DEPRECATION_WARNING, FutureWarning) assert os.path.isdir(file_dir) logger.info(f"Creating features from dataset file folder at {file_dir}") self.examples = [] @@ -305,6 +315,7 @@ def __init__( short_seq_probability=0.1, nsp_probability=0.5, ): + warnings.warn(DEPRECATION_WARNING, FutureWarning) assert os.path.isfile(file_path), f"Input file path {file_path} not found" self.block_size = block_size - tokenizer.num_special_tokens_to_add(pair=True) @@ -449,9 +460,18 @@ def create_examples_from_document(self, document: List[List[int]], doc_index: in assert len(tokens_a) >= 1 assert len(tokens_b) >= 1 - self.examples.append( - {"tokens_a": tokens_a, "tokens_b": tokens_b, "is_random_next": is_random_next} - ) + # add special tokens + input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b) + # add token type ids, 0 for sentence a, 1 for sentence b + token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b) + + example = { + "input_ids": torch.tensor(input_ids, dtype=torch.long), + "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long), + "next_sentence_label": torch.tensor(1 if is_random_next else 0, dtype=torch.long), + } + + self.examples.append(example) current_chunk = [] current_length = 0 diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c6d70a53615a..b37290388126 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -26,11 +26,6 @@ def from_pretrained(self, *args, **kwargs): requires_pytorch(self) -class DataCollatorForNextSentencePrediction: - def __init__(self, *args, **kwargs): - requires_pytorch(self) - - class DataCollatorForPermutationLanguageModeling: def __init__(self, *args, **kwargs): requires_pytorch(self) diff --git a/tests/test_data_collator.py b/tests/test_data_collator.py index d46a96589cab..d090b3eff285 100644 --- a/tests/test_data_collator.py +++ b/tests/test_data_collator.py @@ -12,9 +12,7 @@ from transformers import ( DataCollatorForLanguageModeling, - DataCollatorForNextSentencePrediction, DataCollatorForPermutationLanguageModeling, - DataCollatorForSOP, DataCollatorForTokenClassification, DataCollatorWithPadding, 
default_data_collator, @@ -201,13 +199,16 @@ def test_plm(self): def test_nsp(self): tokenizer = BertTokenizer(self.vocab_file) - features = [{"tokens_a": [0, 1, 2, 3, 4], "tokens_b": [0, 1, 2, 3, 4], "is_random_next": i} for i in range(2)] - data_collator = DataCollatorForNextSentencePrediction(tokenizer) + features = [ + {"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i} + for i in range(2) + ] + data_collator = DataCollatorForLanguageModeling(tokenizer) batch = data_collator(features) - self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512))) - self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 512))) - self.assertEqual(batch["labels"].shape, torch.Size((2, 512))) + self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5))) + self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5))) + self.assertEqual(batch["labels"].shape, torch.Size((2, 5))) self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,))) def test_sop(self): @@ -216,11 +217,11 @@ def test_sop(self): { "input_ids": torch.tensor([0, 1, 2, 3, 4]), "token_type_ids": torch.tensor([0, 1, 2, 3, 4]), - "sentence_order_label": torch.tensor(i), + "sentence_order_label": i, } for i in range(2) ] - data_collator = DataCollatorForSOP(tokenizer) + data_collator = DataCollatorForLanguageModeling(tokenizer) batch = data_collator(features) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5))) From 439a87d0330139257ac4acadc6855ebde223f619 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 4 Nov 2020 17:07:03 -0500 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: Lysandre Debut --- src/transformers/data/data_collator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index a634852252da..1ac14c5dbe46 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -192,7 +192,7 @@ def _collate_batch(examples, tokenizer): if tokenizer._pad_token is None: raise ValueError( "You are attempting to pad samples but the tokenizer you are using" - f" ({tokenizer.__class__.__name__}) does not have one." + f" ({tokenizer.__class__.__name__}) does not have a pad token." ) # Creating the full tensor and filling it with our data. @@ -238,7 +238,7 @@ def __post_init__(self): if self.mlm and self.tokenizer.mask_token is None: raise ValueError( "This tokenizer does not have a mask token which is necessary for masked language modeling. " - "You should pass `mlm=False`." + "You should pass `mlm=False` to train on causal language modeling instead." ) def __call__( From 90fa02a5e3d99729127c02ea9038e25c1729a43c Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 4 Nov 2020 17:20:03 -0500 Subject: [PATCH 3/3] Remove needless clone --- src/transformers/data/data_collator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 1ac14c5dbe46..5711845896f5 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -257,7 +257,7 @@ def __call__( batch["input_ids"], special_tokens_mask=special_tokens_mask ) else: - labels = batch["input_ids"].clone().detach() + labels = batch["input_ids"] if self.tokenizer.pad_token_id is not None: labels[labels == self.tokenizer.pad_token_id] = -100 batch["labels"] = labels
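
Note (not part of the patch): a minimal usage sketch of the refactored `DataCollatorForLanguageModeling`, illustrating the pattern the series introduces — tokenize with `return_special_tokens_mask=True` so the collator can skip recomputing the mask, and let it pad each batch dynamically. The checkpoint name and sample texts are arbitrary examples chosen for illustration.

    # Illustrative sketch only; assumes the "bert-base-uncased" checkpoint is available locally or via the hub.
    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Path 1: dict/BatchEncoding examples. Passing `return_special_tokens_mask=True` is the
    # optimization run_mlm.py now uses: the collator pops the precomputed mask instead of
    # calling `get_special_tokens_mask` on every batch.
    texts = ["Hello world!", "Data collators now pad dynamically to the longest example."]
    examples = [tokenizer(t, return_special_tokens_mask=True) for t in texts]

    collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)
    batch = collator(examples)
    # `input_ids` are padded per batch; `labels` are -100 everywhere except masked positions.
    print(batch["input_ids"].shape, batch["labels"].shape)

    # Path 2: plain lists of token ids, which exercises the new `_collate_batch` helper
    # (tensorize, then pad with `tokenizer.pad_token_id` only if lengths differ).
    raw_ids = [[101, 7592, 2088, 102], [101, 2951, 102]]
    batch = collator(raw_ids)
    print(batch["input_ids"].shape)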