@@ -1098,7 +1098,7 @@ def contains_type(type_hint, target_type) -> tuple[bool, Optional[object]]:
10981098 if args == ():
10991099 try :
11001100 return issubclass (type_hint , target_type ), type_hint
1101- except Exception as _ :
1101+ except Exception :
11021102 return issubclass (type (type_hint ), target_type ), type_hint
11031103 found_type_tuple = [contains_type (arg , target_type )[0 ] for arg in args ]
11041104 found_type = any (found_type_tuple )
@@ -1112,6 +1112,8 @@ def get_model_name(obj):
11121112 Get the model name from the file path of the object.
11131113 """
11141114 path = inspect .getsourcefile (obj )
1115+ if path is None :
1116+ return None
11151117 if path .split (os .path .sep )[- 3 ] != "models" :
11161118 return None
11171119 file_name = path .split (os .path .sep )[- 1 ]
@@ -1783,9 +1785,10 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No
17831785
17841786 is_dataclass = False
17851787 docstring_init = ""
1788+ docstring_args = ""
17861789 if "PreTrainedModel" in (x .__name__ for x in cls .__mro__ ):
17871790 docstring_init = auto_method_docstring (
1788- cls .__init__ , parent_class = cls , custom_args = custom_args
1791+ cls .__init__ , parent_class = cls , custom_args = custom_args , checkpoint = checkpoint
17891792 ).__doc__ .replace ("Args:" , "Parameters:" )
17901793 elif "ModelOutput" in (x .__name__ for x in cls .__mro__ ):
17911794 # We have a data class
@@ -1797,6 +1800,7 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No
17971800 cls .__init__ ,
17981801 parent_class = cls ,
17991802 custom_args = custom_args ,
1803+ checkpoint = checkpoint ,
18001804 source_args_dict = get_args_doc_from_source (ModelOutputArgs ),
18011805 ).__doc__
18021806 indent_level = get_indent_level (cls )
@@ -1836,7 +1840,7 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No
18361840 docstring += docstring_args if docstring_args else "\n Args:\n "
18371841 source_args_dict = get_args_doc_from_source (ModelOutputArgs )
18381842 doc_class = cls .__doc__ if cls .__doc__ else ""
1839- documented_kwargs , _ = parse_docstring (doc_class )
1843+ documented_kwargs = parse_docstring (doc_class )[ 0 ]
18401844 for param_name , param_type_annotation in cls .__annotations__ .items ():
18411845 param_type = str (param_type_annotation )
18421846 optional = False
0 commit comments