Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 21 additions & 174 deletions python/paddle/_paddle_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,201 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import inspect
from typing import Any

import paddle

from .base.dygraph.generated_tensor_methods_patch import methods_map

# Add docstr for some C++ functions in paddle
_add_docstr = paddle.base.core.eager._add_docstr
_code_template = R"""
from __future__ import annotations

def _parse_function_signature(
func_str: str,
) -> tuple[inspect.Signature, str, dict]:
"""
Return the inspect.Signaturn for Python function and string signature
such as "(x,axis=None)" for builtin_function
"""
func_str = func_str.strip()

if not func_str.startswith('def '):
func_str = 'def ' + func_str

# Create a complete function
full_def = func_str + ":\n pass"

try:
# Parse AST
module = ast.parse(full_def)
func_def = next(
node for node in module.body if isinstance(node, ast.FunctionDef)
)
except Exception as e:
raise ValueError(f"Failed to parse function definition: {e}") from e

builtin_annotations_dict = {}

# Get return annotation
return_annotation = inspect.Signature.empty
if func_def.returns:
return_annotation = _ast_unparse(func_def.returns)
if return_annotation is not inspect.Signature.empty:
builtin_annotations_dict.update({"return": str(return_annotation)})

builtin_sig_str = "("
# Create parameters
parameters = []
count = 0

# Process the POSITIONAL_OR_KEYWORD parameters
for param in func_def.args.posonlyargs + func_def.args.args:
param_name = param.arg
builtin_param_str = param_name

annotation = inspect.Parameter.empty
if param.annotation:
annotation = _ast_unparse(param.annotation)
builtin_annotations_dict.update({param_name: str(annotation)})
# Get Default value
default = inspect.Parameter.empty

if func_def.args.defaults and len(func_def.args.defaults) > (
len(func_def.args.args) - len(func_def.args.defaults)
):

idx = count - (
len(func_def.args.args) - len(func_def.args.defaults)
)
if idx >= 0:
default_node = func_def.args.defaults[idx]
default = _ast_literal_eval(default_node)
builtin_param_str += " = " + str(default)

# Create inspect.Parameter
param_obj = inspect.Parameter(
name=param_name,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=default,
annotation=annotation,
)
builtin_sig_str += f"{builtin_param_str},"

count += 1
parameters.append(param_obj)

# Process the key word only params such as out
count = 0
if len(func_def.args.kwonlyargs) > 0:
builtin_sig_str += "*,"
for param in func_def.args.kwonlyargs:
para_name = param.arg
builtin_param_str = param_name
annotation = (
_ast_unparse(param.annotation)
if param.annotation
else inspect.Parameter.empty
)
if param.annotation:
builtin_annotations_dict.update({param_name: str(annotation)})
idx = count
default = inspect.Parameter.empty
if idx >= 0 and idx < len(func_def.args.kw_defaults):
default_node = func_def.args.kw_defaults[idx]
default = _ast_literal_eval(default_node)
builtin_param_str += " = " + str(default)
parameters.append(
inspect.Parameter(
name=para_name,
kind=inspect.Parameter.KEYWORD_ONLY,
default=default,
annotation=annotation,
)
)
builtin_sig_str += f"{builtin_param_str}"
count += 1

builtin_sig_str += ")"
# Create inspect.Signature and return builtin_sig_str
return (
inspect.Signature(
parameters=parameters, return_annotation=return_annotation
),
builtin_sig_str,
builtin_annotations_dict,
)


def _ast_unparse(node: ast.AST) -> str:
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Subscript):
value = _ast_unparse(node.value)
slice_str = _ast_unparse(node.slice)
return f"{value}[{slice_str}]"
elif isinstance(node, ast.Index):
return _ast_unparse(node.value)
elif isinstance(node, ast.Constant):
# process string
if isinstance(node.value, str):
return f"'{node.value}'"
return str(node.value)
elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
left = _ast_unparse(node.left)
right = _ast_unparse(node.right)
return f"{left} | {right}"
elif isinstance(node, ast.Attribute):
return f"{_ast_unparse(node.value)}.{node.attr}"
elif isinstance(node, ast.Tuple):
return ", ".join(_ast_unparse(el) for el in node.elts)
else:
return ast.dump(node)


def _ast_literal_eval(node: ast.AST) -> Any:
"""Eval and transpose AST node to Python literal"""
if isinstance(node, ast.Constant):
return node.value
elif isinstance(node, ast.NameConstant):
return node.value
elif isinstance(node, ast.Num):
return node.n
elif isinstance(node, ast.Str):
return node.s
elif isinstance(node, ast.Name) and node.id == "None":
return None
elif isinstance(node, ast.Name) and node.id == "True":
return True
elif isinstance(node, ast.Name) and node.id == "False":
return False
else:
raise ValueError(f"Unsupported default value: {ast.dump(node)}")
{}:
...

"""

# Add docstr for some C++ functions in paddle
_add_docstr = paddle.base.core.eager._add_docstr

def _parse_function_signature(func_name: str, code: str) -> inspect.Signature:
code = _code_template.format(code.strip())
code_obj = compile(code, "<string>", "exec")
globals = {}
eval(code_obj, globals)
return inspect.signature(globals[func_name])


def add_doc_and_signature(method: str, docstr: str, signature: str) -> None:
def add_doc_and_signature(func_name: str, docstr: str, func_def: str) -> None:
"""
Add docstr for function (paddle.*) and method (paddle.Tensor.*) if method exists
"""
# builtin_sig = "(a,b=1,c=0)"
python_api_sig, builtin_sig, builtin_ann = _parse_function_signature(
signature
)
python_api_sig = _parse_function_signature(func_name, func_def)
for module in [paddle, paddle.Tensor]:
if hasattr(module, method):
func = getattr(module, method)
if hasattr(module, func_name):
func = getattr(module, func_name)
if inspect.isfunction(func):
func.__doc__ = docstr
elif inspect.ismethod(func):
func.__self__.__doc__ = docstr
elif inspect.isbuiltin(func):
_add_docstr(func, docstr, builtin_sig, builtin_ann)
_add_docstr(func, docstr)
methods_dict = dict(methods_map)
if method in methods_dict.keys():
tensor_func = methods_dict[method]
if func_name in methods_dict.keys():
tensor_func = methods_dict[func_name]
tensor_func.__signature__ = python_api_sig


Expand Down
Loading