Commit a1661bc

changjonathanc authored and stas00 committed
Fix DataCollatorForWholeWordMask (huggingface#8379)
* Fix DataCollatorForWholeWordMask
* Replace all tensorize_batch in data_collator.py
1 parent 59d6bc3 commit a1661bc

1 file changed: src/transformers/data/data_collator.py (3 additions & 21 deletions)
```diff
@@ -315,7 +315,7 @@ def __call__(
             input_ids = examples
             examples = [{"input_ids": e} for e in examples]
 
-        batch_input = self._tensorize_batch(input_ids)
+        batch_input = _collate_batch(input_ids, self.tokenizer)
 
         mask_labels = []
         for e in examples:
@@ -332,7 +332,7 @@ def __call__(
                     if i in ref_pos:
                         ref_tokens[i] = "##" + ref_tokens[i]
             mask_labels.append(self._whole_word_mask(ref_tokens))
-        batch_mask = self._tensorize_batch(mask_labels)
+        batch_mask = _collate_batch(mask_labels, self.tokenizer)
         inputs, labels = self.mask_tokens(batch_input, batch_mask)
         return {"input_ids": inputs, "labels": labels}
 
@@ -511,28 +511,10 @@ def __call__(
     ) -> Dict[str, torch.Tensor]:
         if isinstance(examples[0], (dict, BatchEncoding)):
             examples = [e["input_ids"] for e in examples]
-        batch = self._tensorize_batch(examples)
+        batch = _collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
 
-    def _tensorize_batch(
-        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
-    ) -> torch.Tensor:
-        # In order to accept both lists of lists and lists of Tensors
-        if isinstance(examples[0], (list, tuple)):
-            examples = [torch.Tensor(e) for e in examples]
-        length_of_first = examples[0].size(0)
-        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-        if are_tensors_same_length:
-            return torch.stack(examples, dim=0)
-        else:
-            if self.tokenizer._pad_token is None:
-                raise ValueError(
-                    "You are attempting to pad samples but the tokenizer you are using"
-                    f" ({self.tokenizer.__class__.__name__}) does not have one."
-                )
-            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
-
     def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
```
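The hunks above call a module-level `_collate_batch(examples, tokenizer)` helper whose body is not part of this diff. The following is a minimal sketch of what such a shared helper plausibly looks like, reconstructed from the removed `_tensorize_batch` logic; the signature is taken from the call sites, while the right-padding behavior and preallocation strategy are assumptions, not code copied from the commit:

```python
from typing import List, Union

import torch


def _collate_batch(examples: List[Union[List[int], torch.Tensor]], tokenizer) -> torch.Tensor:
    """Collate `examples` into one tensor, padding to the longest length if needed."""
    # Accept both lists of ints and 1-D tensors.
    if isinstance(examples[0], (list, tuple)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    # If all examples share one length, a plain stack is enough.
    length_of_first = examples[0].size(0)
    if all(x.size(0) == length_of_first for x in examples):
        return torch.stack(examples, dim=0)

    # Otherwise a pad token is required to fill the shorter sequences.
    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have one."
        )

    # Right-pad every example into a preallocated (batch, max_length) tensor
    # (assumed padding direction; the real helper may honor tokenizer.padding_side).
    max_length = max(x.size(0) for x in examples)
    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        result[i, : example.size(0)] = example
    return result
```

Moving this logic to a module-level function means each collator no longer re-implements (or inherits) its own tensorize step, which per the commit title is where `DataCollatorForWholeWordMask` broke: it was still calling `self._tensorize_batch`, a method that no longer exists on the class.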

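For context, a minimal usage sketch of the fixed collator on variable-length inputs; the checkpoint name and masking probability are illustrative, not from this commit:

```python
from transformers import BertTokenizer, DataCollatorForWholeWordMask

# Any BERT-style WordPiece tokenizer works; whole-word masking relies on "##" subword prefixes.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.15)

# Examples of different lengths: the shared collate helper pads them
# to a common length with the tokenizer's pad token.
examples = [
    tokenizer("a somewhat longer example sentence")["input_ids"],
    tokenizer("short")["input_ids"],
]
batch = collator(examples)
print(batch["input_ids"].shape)  # (2, longest_length_in_batch)
print(batch["labels"].shape)     # same shape; -100 marks positions with no MLM loss
```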