Skip to content

Commit 753876d

Browse files
SaulLuelusenji
authored andcommitted
fix retribert's test_torch_encode_plus_sent_to_model (huggingface#17231)
1 parent a2fdc29 commit 753876d

File tree

1 file changed

+46
-2
lines changed

1 file changed

+46
-2
lines changed

tests/models/retribert/test_tokenization_retribert.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
_is_punctuation,
2828
_is_whitespace,
2929
)
30-
from transformers.testing_utils import require_tokenizers, slow
30+
from transformers.testing_utils import require_tokenizers, require_torch, slow
3131

32-
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
32+
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
3333

3434

3535
# Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
@@ -338,3 +338,47 @@ def test_change_tokenize_chinese_chars(self):
338338
]
339339
self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
340340
self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
341+
342+
# RetriBertModel doesn't define `get_input_embeddings` and it's forward method doesn't take only the output of the tokenizer as input
343+
@require_torch
344+
@slow
345+
def test_torch_encode_plus_sent_to_model(self):
346+
import torch
347+
348+
from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
349+
350+
MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
351+
352+
tokenizers = self.get_tokenizers(do_lower_case=False)
353+
for tokenizer in tokenizers:
354+
with self.subTest(f"{tokenizer.__class__.__name__}"):
355+
356+
if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
357+
return
358+
359+
config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
360+
config = config_class()
361+
362+
if config.is_encoder_decoder or config.pad_token_id is None:
363+
return
364+
365+
model = model_class(config)
366+
367+
# The following test is different from the common's one
368+
self.assertGreaterEqual(model.bert_query.get_input_embeddings().weight.shape[0], len(tokenizer))
369+
370+
# Build sequence
371+
first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
372+
sequence = " ".join(first_ten_tokens)
373+
encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
374+
375+
# Ensure that the BatchEncoding.to() method works.
376+
encoded_sequence.to(model.device)
377+
378+
batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
379+
# This should not fail
380+
381+
with torch.no_grad(): # saves some time
382+
# The following lines are different from the common's ones
383+
model.embed_questions(**encoded_sequence)
384+
model.embed_questions(**batch_encoded_sequence)

0 commit comments

Comments
 (0)