@@ -55,6 +55,28 @@ class TimmWrapperModelOutput(ModelOutput):
5555 attentions : Optional [tuple [torch .FloatTensor , ...]] = None
5656
5757
58+ def _create_timm_model_with_error_handling (config : "TimmWrapperConfig" , ** model_kwargs ):
59+ """
60+ Creates a timm model and provides a clear error message if the model is not found,
61+ suggesting a library update.
62+ """
63+ try :
64+ model = timm .create_model (
65+ config .architecture ,
66+ pretrained = False ,
67+ ** model_kwargs ,
68+ )
69+ return model
70+ except RuntimeError as e :
71+ if "Unknown model" in str (e ):
72+ # A good general check for unknown models.
73+ raise ImportError (
74+ f"The model architecture '{ config .architecture } ' is not supported in your version of timm ({ timm .__version__ } ). "
75+ "Please upgrade timm to a more recent version with `pip install -U timm`."
76+ ) from e
77+ raise e
78+
79+
5880@auto_docstring
5981class TimmWrapperPreTrainedModel (PreTrainedModel ):
6082 main_input_name = "pixel_values"
@@ -138,7 +160,7 @@ def __init__(self, config: TimmWrapperConfig):
138160 super ().__init__ (config )
139161 # using num_classes=0 to avoid creating classification head
140162 extra_init_kwargs = config .model_args or {}
141- self .timm_model = timm . create_model (config . architecture , pretrained = False , num_classes = 0 , ** extra_init_kwargs )
163+ self .timm_model = _create_timm_model_with_error_handling (config , num_classes = 0 , ** extra_init_kwargs )
142164 self .post_init ()
143165
144166 @auto_docstring
@@ -254,8 +276,8 @@ def __init__(self, config: TimmWrapperConfig):
254276 )
255277
256278 extra_init_kwargs = config .model_args or {}
257- self .timm_model = timm . create_model (
258- config . architecture , pretrained = False , num_classes = config .num_labels , ** extra_init_kwargs
279+ self .timm_model = _create_timm_model_with_error_handling (
280+ config , num_classes = config .num_labels , ** extra_init_kwargs
259281 )
260282 self .num_labels = config .num_labels
261283 self .post_init ()
0 commit comments