@@ -27,9 +27,9 @@
     _is_punctuation,
     _is_whitespace,
 )
-from transformers.testing_utils import require_tokenizers, slow
+from transformers.testing_utils import require_tokenizers, require_torch, slow
 
-from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
+from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
 
 
 # Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
@@ -338,3 +338,47 @@ def test_change_tokenize_chinese_chars(self):
         ]
         self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
         self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
+
+    # RetriBertModel doesn't define `get_input_embeddings`, and its forward method doesn't take only the tokenizer's output as input
+    @require_torch
+    @slow
+    def test_torch_encode_plus_sent_to_model(self):
+        import torch
+
+        from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
+
+        MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
+
+        tokenizers = self.get_tokenizers(do_lower_case=False)
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+
+                if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
+                    return
+
+                config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
+                config = config_class()
+
+                if config.is_encoder_decoder or config.pad_token_id is None:
+                    return
+
+                model = model_class(config)
+
+                # The following check differs from the common test: the input embeddings live on the inner `bert_query` encoder
+                self.assertGreaterEqual(model.bert_query.get_input_embeddings().weight.shape[0], len(tokenizer))
+
+                # Build sequence
+                first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
+                sequence = " ".join(first_ten_tokens)
+                encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
+
+                # Ensure that the BatchEncoding.to() method works.
+                encoded_sequence.to(model.device)
+
+                batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
+                # This should not fail
+
+                with torch.no_grad():  # saves some time
+                    # The following calls differ from the common test: RetriBert is queried through `embed_questions`
+                    model.embed_questions(**encoded_sequence)
+                    model.embed_questions(**batch_encoded_sequence)
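
For context, here's a minimal sketch (not part of this diff) of the two RetriBert quirks the new test works around: `get_input_embeddings` lives on the inner `bert_query` encoder rather than on the wrapper, and inference goes through `embed_questions` instead of a plain `forward(**tokenizer_output)` call. The token ids below are made up for illustration, and this assumes a transformers version that still ships RetriBert:

```python
import torch

from transformers import RetriBertConfig, RetriBertModel

# Randomly initialized model; no weights are downloaded
config = RetriBertConfig()
model = RetriBertModel(config)

# The wrapper has no `get_input_embeddings`; the inner query encoder does
embeddings = model.bert_query.get_input_embeddings()
print(embeddings.weight.shape)  # (vocab_size, hidden_size)

# Illustrative ids only: [CLS] ... [SEP] under the default BERT vocab size
input_ids = torch.tensor([[101, 7592, 2088, 102]])
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    q_reps = model.embed_questions(input_ids, attention_mask)
print(q_reps.shape)  # (1, config.projection_dim), 128 by default
```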