Skip to content

Commit ddb1a47

Browse files
authored
Automatically sort auto mappings (#17250)
* Automatically sort auto mappings * Better class extraction * Some auto class magic * Adapt test and underlying behavior * Remove re-used config * Quality
1 parent 2f611f8 commit ddb1a47

File tree

14 files changed

+1096
-991
lines changed

14 files changed

+1096
-991
lines changed

.circleci/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,7 @@ jobs:
857857
- run: black --check --preview examples tests src utils
858858
- run: isort --check-only examples tests src utils
859859
- run: python utils/custom_init_isort.py --check_only
860+
- run: python utils/sort_auto_mappings.py --check_only
860861
- run: flake8 examples tests src utils
861862
- run: doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
862863

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,15 @@ quality:
4848
black --check --preview $(check_dirs)
4949
isort --check-only $(check_dirs)
5050
python utils/custom_init_isort.py --check_only
51+
python utils/sort_auto_mappings.py --check_only
5152
flake8 $(check_dirs)
5253
doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
5354

5455
# Format source code automatically and check is there are any problems left that need manual fixing
5556

5657
extra_style_checks:
5758
python utils/custom_init_isort.py
59+
python utils/sort_auto_mappings.py
5860
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
5961

6062
# this target runs checks on all files and potentially modifies some of them

docs/source/en/index.mdx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,6 @@ Flax), PyTorch, and/or TensorFlow.
259259
| Swin | | | | | |
260260
| T5 | | | | | |
261261
| TAPAS | | | | | |
262-
| TAPEX | | | | | |
263262
| Transformer-XL | | | | | |
264263
| TrOCR | | | | | |
265264
| UniSpeech | | | | | |

docs/source/en/serialization.mdx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ Ready-made configurations include the following architectures:
7474
- RoBERTa
7575
- RoFormer
7676
- T5
77-
- TAPEX
7877
- ViT
7978
- XLM-RoBERTa
8079
- XLM-RoBERTa-XL

src/transformers/models/auto/auto_factory.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -560,10 +560,17 @@ def __getitem__(self, key):
560560
if key in self._extra_content:
561561
return self._extra_content[key]
562562
model_type = self._reverse_config_mapping[key.__name__]
563-
if model_type not in self._model_mapping:
564-
raise KeyError(key)
565-
model_name = self._model_mapping[model_type]
566-
return self._load_attr_from_module(model_type, model_name)
563+
if model_type in self._model_mapping:
564+
model_name = self._model_mapping[model_type]
565+
return self._load_attr_from_module(model_type, model_name)
566+
567+
# Maybe there was several model types associated with this config.
568+
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
569+
for mtype in model_types:
570+
if mtype in self._model_mapping:
571+
model_name = self._model_mapping[mtype]
572+
return self._load_attr_from_module(mtype, model_name)
573+
raise KeyError(key)
567574

568575
def _load_attr_from_module(self, model_type, attr):
569576
module_name = model_type_to_module_name(model_type)

src/transformers/models/auto/configuration_auto.py

Lines changed: 272 additions & 272 deletions
Large diffs are not rendered by default.

src/transformers/models/auto/feature_extraction_auto.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,30 +38,30 @@
3838
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
3939
[
4040
("beit", "BeitFeatureExtractor"),
41-
("detr", "DetrFeatureExtractor"),
42-
("deit", "DeiTFeatureExtractor"),
43-
("hubert", "Wav2Vec2FeatureExtractor"),
44-
("speech_to_text", "Speech2TextFeatureExtractor"),
45-
("vit", "ViTFeatureExtractor"),
46-
("wav2vec2", "Wav2Vec2FeatureExtractor"),
47-
("detr", "DetrFeatureExtractor"),
48-
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
4941
("clip", "CLIPFeatureExtractor"),
50-
("flava", "FlavaFeatureExtractor"),
51-
("perceiver", "PerceiverFeatureExtractor"),
52-
("swin", "ViTFeatureExtractor"),
53-
("vit_mae", "ViTFeatureExtractor"),
54-
("segformer", "SegformerFeatureExtractor"),
5542
("convnext", "ConvNextFeatureExtractor"),
56-
("van", "ConvNextFeatureExtractor"),
57-
("resnet", "ConvNextFeatureExtractor"),
58-
("regnet", "ConvNextFeatureExtractor"),
59-
("poolformer", "PoolFormerFeatureExtractor"),
60-
("maskformer", "MaskFormerFeatureExtractor"),
6143
("data2vec-audio", "Wav2Vec2FeatureExtractor"),
6244
("data2vec-vision", "BeitFeatureExtractor"),
45+
("deit", "DeiTFeatureExtractor"),
46+
("detr", "DetrFeatureExtractor"),
47+
("detr", "DetrFeatureExtractor"),
6348
("dpt", "DPTFeatureExtractor"),
49+
("flava", "FlavaFeatureExtractor"),
6450
("glpn", "GLPNFeatureExtractor"),
51+
("hubert", "Wav2Vec2FeatureExtractor"),
52+
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
53+
("maskformer", "MaskFormerFeatureExtractor"),
54+
("perceiver", "PerceiverFeatureExtractor"),
55+
("poolformer", "PoolFormerFeatureExtractor"),
56+
("regnet", "ConvNextFeatureExtractor"),
57+
("resnet", "ConvNextFeatureExtractor"),
58+
("segformer", "SegformerFeatureExtractor"),
59+
("speech_to_text", "Speech2TextFeatureExtractor"),
60+
("swin", "ViTFeatureExtractor"),
61+
("van", "ConvNextFeatureExtractor"),
62+
("vit", "ViTFeatureExtractor"),
63+
("vit_mae", "ViTFeatureExtractor"),
64+
("wav2vec2", "Wav2Vec2FeatureExtractor"),
6565
("yolos", "YolosFeatureExtractor"),
6666
]
6767
)
@@ -75,8 +75,10 @@ def feature_extractor_class_from_name(class_name: str):
7575
module_name = model_type_to_module_name(module_name)
7676

7777
module = importlib.import_module(f".{module_name}", "transformers.models")
78-
return getattr(module, class_name)
79-
break
78+
try:
79+
return getattr(module, class_name)
80+
except AttributeError:
81+
continue
8082

8183
for config, extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.items():
8284
if getattr(extractor, "__name__", None) == class_name:

0 commit comments

Comments
 (0)