Skip to content

Commit d42e96a

Browse files
authored
Use checkpoint in auto_class_docstring (#40844)
Signed-off-by: Yuanyuan Chen <[email protected]>
1 parent 6eb3255 commit d42e96a

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

src/transformers/utils/auto_docstring.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,7 +1098,7 @@ def contains_type(type_hint, target_type) -> tuple[bool, Optional[object]]:
10981098
if args == ():
10991099
try:
11001100
return issubclass(type_hint, target_type), type_hint
1101-
except Exception as _:
1101+
except Exception:
11021102
return issubclass(type(type_hint), target_type), type_hint
11031103
found_type_tuple = [contains_type(arg, target_type)[0] for arg in args]
11041104
found_type = any(found_type_tuple)
@@ -1112,6 +1112,8 @@ def get_model_name(obj):
11121112
Get the model name from the file path of the object.
11131113
"""
11141114
path = inspect.getsourcefile(obj)
1115+
if path is None:
1116+
return None
11151117
if path.split(os.path.sep)[-3] != "models":
11161118
return None
11171119
file_name = path.split(os.path.sep)[-1]
@@ -1783,9 +1785,10 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No
17831785

17841786
is_dataclass = False
17851787
docstring_init = ""
1788+
docstring_args = ""
17861789
if "PreTrainedModel" in (x.__name__ for x in cls.__mro__):
17871790
docstring_init = auto_method_docstring(
1788-
cls.__init__, parent_class=cls, custom_args=custom_args
1791+
cls.__init__, parent_class=cls, custom_args=custom_args, checkpoint=checkpoint
17891792
).__doc__.replace("Args:", "Parameters:")
17901793
elif "ModelOutput" in (x.__name__ for x in cls.__mro__):
17911794
# We have a data class
@@ -1797,6 +1800,7 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No
17971800
cls.__init__,
17981801
parent_class=cls,
17991802
custom_args=custom_args,
1803+
checkpoint=checkpoint,
18001804
source_args_dict=get_args_doc_from_source(ModelOutputArgs),
18011805
).__doc__
18021806
indent_level = get_indent_level(cls)
@@ -1836,7 +1840,7 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No
18361840
docstring += docstring_args if docstring_args else "\nArgs:\n"
18371841
source_args_dict = get_args_doc_from_source(ModelOutputArgs)
18381842
doc_class = cls.__doc__ if cls.__doc__ else ""
1839-
documented_kwargs, _ = parse_docstring(doc_class)
1843+
documented_kwargs = parse_docstring(doc_class)[0]
18401844
for param_name, param_type_annotation in cls.__annotations__.items():
18411845
param_type = str(param_type_annotation)
18421846
optional = False

0 commit comments

Comments
 (0)