@@ -315,7 +315,7 @@ def __call__(
             input_ids = examples
             examples = [{"input_ids": e} for e in examples]

-        batch_input = self._tensorize_batch(input_ids)
+        batch_input = _collate_batch(input_ids, self.tokenizer)

         mask_labels = []
         for e in examples:
@@ -332,7 +332,7 @@ def __call__(
                     if i in ref_pos:
                         ref_tokens[i] = "##" + ref_tokens[i]
             mask_labels.append(self._whole_word_mask(ref_tokens))
-        batch_mask = self._tensorize_batch(mask_labels)
+        batch_mask = _collate_batch(mask_labels, self.tokenizer)
         inputs, labels = self.mask_tokens(batch_input, batch_mask)
         return {"input_ids": inputs, "labels": labels}

@@ -511,28 +511,10 @@ def __call__(
     ) -> Dict[str, torch.Tensor]:
         if isinstance(examples[0], (dict, BatchEncoding)):
             examples = [e["input_ids"] for e in examples]
-        batch = self._tensorize_batch(examples)
+        batch = _collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

-    def _tensorize_batch(
-        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
-    ) -> torch.Tensor:
-        # In order to accept both lists of lists and lists of Tensors
-        if isinstance(examples[0], (list, tuple)):
-            examples = [torch.Tensor(e) for e in examples]
-        length_of_first = examples[0].size(0)
-        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-        if are_tensors_same_length:
-            return torch.stack(examples, dim=0)
-        else:
-            if self.tokenizer._pad_token is None:
-                raise ValueError(
-                    "You are attempting to pad samples but the tokenizer you are using"
-                    f" ({self.tokenizer.__class__.__name__}) does not have one."
-                )
-            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
-
     def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
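Note: the module-level `_collate_batch(examples, tokenizer)` helper that these call sites now rely on is defined elsewhere in `data_collator.py` and is not shown in this excerpt. Based on the removed `_tensorize_batch` logic and the new call signature, a minimal sketch of such a helper might look like the following; the `dtype=torch.long` choice and the explicit right-padding loop are assumptions here, not confirmed by this diff:

```python
import torch
from typing import List, Union


def _collate_batch(examples: List[Union[List[int], torch.Tensor]], tokenizer) -> torch.Tensor:
    """Collate `examples` into a batch tensor, padding with the tokenizer's pad token if needed."""
    # Accept both lists of ints and tensors, mirroring the removed _tensorize_batch.
    if isinstance(examples[0], (list, tuple)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    # If all sequences share one length, stacking suffices and no pad token is required.
    length_of_first = examples[0].size(0)
    if all(x.size(0) == length_of_first for x in examples):
        return torch.stack(examples, dim=0)

    # Ragged batches need padding, which only works if the tokenizer defines a pad token.
    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have one."
        )

    # Right-pad every example up to the longest sequence in the batch.
    max_length = max(x.size(0) for x in examples)
    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        result[i, : example.size(0)] = example
    return result
```

Hoisting the helper out of the individual collator classes lets every collator in this file share a single tensorize-and-pad path instead of each carrying its own `_tensorize_batch` copy; it also avoids the removed method's `torch.Tensor(e)` call, which would have produced float tensors from lists of token ids.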