Skip to content

Commit d5a669f

Browse files
committed
Adapt test and underlying behavior
1 parent 0637f2e commit d5a669f

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

src/transformers/models/auto/configuration_auto.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,9 +703,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
703703
return config_class.from_dict(config_dict, **kwargs)
704704
else:
705705
# Fallback: use pattern matching on the string.
706-
for pattern, config_class in CONFIG_MAPPING.items():
706+
# We go from longer names to shorter names to catch roberta before bert (for instance)
707+
for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
707708
if pattern in str(pretrained_model_name_or_path):
708-
return config_class.from_dict(config_dict, **kwargs)
709+
return CONFIG_MAPPING[pattern].from_dict(config_dict, **kwargs)
709710

710711
raise ValueError(
711712
f"Unrecognized model in {pretrained_model_name_or_path}. "

tests/models/auto/test_configuration_auto.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515

1616
import importlib
17+
import json
18+
import os
1719
import sys
1820
import tempfile
1921
import unittest
@@ -56,14 +58,14 @@ def test_config_for_model_str(self):
5658
self.assertIsInstance(config, RobertaConfig)
5759

5860
def test_pattern_matching_fallback(self):
59-
"""
60-
In cases where config.json doesn't include a model_type,
61-
perform a few safety checks on the config mapping's order.
62-
"""
63-
# no key string should be included in a later key string (typical failure case)
64-
keys = list(CONFIG_MAPPING.keys())
65-
for i, key in enumerate(keys):
66-
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
61+
with tempfile.TemporaryDirectory() as tmp_dir:
62+
# This model name contains bert and roberta, but roberta ends up being picked.
63+
folder = os.path.join(tmp_dir, "fake-roberta")
64+
os.makedirs(folder, exist_ok=True)
65+
with open(os.path.join(folder, "config.json"), "w") as f:
66+
f.write(json.dumps({}))
67+
config = AutoConfig.from_pretrained(folder)
68+
self.assertEqual(type(config), RobertaConfig)
6769

6870
def test_new_config_registration(self):
6971
try:

0 commit comments

Comments
 (0)