diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index 5711845896f5..04b55b7b6a33 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -315,7 +315,7 @@ def __call__(
             input_ids = examples
             examples = [{"input_ids": e} for e in examples]
 
-        batch_input = self._tensorize_batch(input_ids)
+        batch_input = _collate_batch(input_ids, self.tokenizer)
 
         mask_labels = []
         for e in examples:
@@ -332,7 +332,7 @@ def __call__(
                 if i in ref_pos:
                     ref_tokens[i] = "##" + ref_tokens[i]
             mask_labels.append(self._whole_word_mask(ref_tokens))
-        batch_mask = self._tensorize_batch(mask_labels)
+        batch_mask = _collate_batch(mask_labels, self.tokenizer)
         inputs, labels = self.mask_tokens(batch_input, batch_mask)
         return {"input_ids": inputs, "labels": labels}
 
@@ -511,28 +511,10 @@ def __call__(
     ) -> Dict[str, torch.Tensor]:
         if isinstance(examples[0], (dict, BatchEncoding)):
             examples = [e["input_ids"] for e in examples]
-        batch = self._tensorize_batch(examples)
+        batch = _collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
 
-    def _tensorize_batch(
-        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
-    ) -> torch.Tensor:
-        # In order to accept both lists of lists and lists of Tensors
-        if isinstance(examples[0], (list, tuple)):
-            examples = [torch.Tensor(e) for e in examples]
-        length_of_first = examples[0].size(0)
-        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-        if are_tensors_same_length:
-            return torch.stack(examples, dim=0)
-        else:
-            if self.tokenizer._pad_token is None:
-                raise ValueError(
-                    "You are attempting to pad samples but the tokenizer you are using"
-                    f" ({self.tokenizer.__class__.__name__}) does not have one."
-                )
-            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
-
     def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
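
Note: the new _collate_batch helper is presumably defined at module level in data_collator.py and its body is not shown in this diff. As a minimal sketch, assuming it simply mirrors the removed _tensorize_batch logic with the tokenizer passed in explicitly (any extra parameters the real helper may take are omitted here):

# Hypothetical sketch, not part of this patch: reconstructed from the removed
# _tensorize_batch method and the call sites _collate_batch(examples, self.tokenizer).
from typing import List, Union

import torch
from torch.nn.utils.rnn import pad_sequence


def _collate_batch(examples: List[Union[List[int], torch.Tensor]], tokenizer) -> torch.Tensor:
    """Collate examples into a batch, padding with the tokenizer's pad token when lengths differ."""
    # Accept both lists of ints and tensors.
    if isinstance(examples[0], (list, tuple)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    # If every example already has the same length, a plain stack is enough.
    length_of_first = examples[0].size(0)
    if all(x.size(0) == length_of_first for x in examples):
        return torch.stack(examples, dim=0)

    # Otherwise pad to the longest example, which requires the tokenizer to define a pad token.
    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have one."
        )
    return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)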