From 60c379897b3adb653009ed4fac28be0708db538f Mon Sep 17 00:00:00 2001
From: Jonathan Chang
Date: Sun, 8 Nov 2020 09:43:45 +0800
Subject: [PATCH] Fix DataCollatorForWholeWordMask again

---
 src/transformers/data/data_collator.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index 04b55b7b6a33..ba94baaa7d0c 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -206,6 +206,10 @@ def _collate_batch(examples, tokenizer):
     return result
 
 
+def tolist(x: Union[List[Any], torch.Tensor]):
+    return x.tolist() if isinstance(x, torch.Tensor) else x
+
+
 @dataclass
 class DataCollatorForLanguageModeling:
     """
@@ -320,13 +324,13 @@ def __call__(
         mask_labels = []
         for e in examples:
             ref_tokens = []
-            for id in e["input_ids"].tolist():
+            for id in tolist(e["input_ids"]):
                 token = self.tokenizer._convert_id_to_token(id)
                 ref_tokens.append(token)
 
             # For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
             if "chinese_ref" in e:
-                ref_pos = e["chinese_ref"].tolist()
+                ref_pos = tolist(e["chinese_ref"])
                 len_seq = e["input_ids"].size(0)
                 for i in range(len_seq):
                     if i in ref_pos:
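
Reviewer note (not part of the patch): a minimal, self-contained sketch of why the helper is needed. Depending on how the dataset is built, e["input_ids"] and e["chinese_ref"] may arrive as torch.Tensor or as a plain Python list, and calling .tolist() unconditionally raises AttributeError on lists. The example dicts below are illustrative, not taken from the repository.

    from typing import Any, List, Union

    import torch


    def tolist(x: Union[List[Any], torch.Tensor]):
        # Mirror of the helper added in this patch: tensors are converted
        # with .tolist(), plain Python lists pass through unchanged.
        return x.tolist() if isinstance(x, torch.Tensor) else x


    # Illustrative examples only: a collator may receive either form.
    example_as_tensor = {"input_ids": torch.tensor([101, 2023, 102])}
    example_as_list = {"input_ids": [101, 2023, 102]}

    assert tolist(example_as_tensor["input_ids"]) == [101, 2023, 102]
    assert tolist(example_as_list["input_ids"]) == [101, 2023, 102]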