diff --git a/python/paddle/jit/dy2static/__init__.py b/python/paddle/jit/dy2static/__init__.py
index d2c90a2c852dbf..fc6dcdc3228bb7 100644
--- a/python/paddle/jit/dy2static/__init__.py
+++ b/python/paddle/jit/dy2static/__init__.py
@@ -30,7 +30,6 @@
     unpack_by_structure as Unpack,
 )
 from .program_translator import convert_to_static  # noqa: F401
-from .static_analysis import StaticAnalysisVisitor  # noqa: F401
 from .transformers import DygraphToStaticAst  # noqa: F401
 from .utils import UndefinedVar, ast_to_source_code, saw  # noqa: F401
 from .variable_trans_func import (  # noqa: F401
diff --git a/python/paddle/jit/dy2static/static_analysis.py b/python/paddle/jit/dy2static/static_analysis.py
deleted file mode 100644
index c239e8aaacf489..00000000000000
--- a/python/paddle/jit/dy2static/static_analysis.py
+++ /dev/null
@@ -1,246 +0,0 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from paddle.utils import gast
-
-from .utils_helper import (
-    binary_op_output_type,
-    index_in_list,
-    is_dygraph_api,
-    is_numpy_api,
-    is_paddle_api,
-    type_from_annotation,
-)
-
-__all__ = []
-
-
-class AstNodeWrapper:
-    """
-    Wrapper for python gast.node. We need a node wrapper because gast.node
-    doesn't store all required information when we are transforming AST.
-    We should collect additional information which the actual transformation
-    needs.
-    """
-
-    def __init__(self, node):
-        self.node = node
-        self.parent = None
-        self.children = []
-        self.node_var_type = {"UNKNOWN"}
-
-
-class StaticAnalysisVisitor:
-    """
-    A class that does static analysis
-    """
-
-    def __init__(self, ast_root=None):
-        if ast_root is not None:
-            self.run(ast_root)
-
-    def run(self, ast_root):
-        self.node_wrapper_root = None
-        self.ancestor_wrappers = []
-        self.node_to_wrapper_map = {}
-
-        self.dfs_visit(ast_root)
-
-    def dfs_visit(self, node):
-        # AST reuses some gast.nodes, such as Param node of expr_context
-        if node not in self.node_to_wrapper_map:
-            cur_wrapper = AstNodeWrapper(node)
-            self.node_to_wrapper_map[node] = cur_wrapper
-        else:
-            cur_wrapper = self.node_to_wrapper_map[node]
-
-        if self.node_wrapper_root is None:
-            self.node_wrapper_root = cur_wrapper
-
-        if len(self.ancestor_wrappers) != 0:
-            last_wrapper = self.ancestor_wrappers[-1]
-            last_wrapper.children.append(cur_wrapper)
-            cur_wrapper.parent = last_wrapper
-
-        self.ancestor_wrappers.append(cur_wrapper)
-        for child in gast.iter_child_nodes(node):
-            self.dfs_visit(child)
-        self.ancestor_wrappers.pop()
-
-        cur_wrapper.node_var_type = self._get_node_var_type(cur_wrapper)
-        return cur_wrapper.node_var_type
-
-    def get_node_wrapper_root(self):
-        return self.node_wrapper_root
-
-    def get_node_to_wrapper_map(self):
-        return self.node_to_wrapper_map
-
-    def is_tensor_node(self, node):
-        tensor_types = {"TENSOR", "PADDLE_RETURN_TYPES"}
-        node_wrapper = self.node_to_wrapper_map.get(node, None)
-        if node_wrapper is None:
-            return False
-        if node_wrapper.node_var_type & tensor_types:
-            return True
-
-    def _get_constant_node_type(self, node):
-        assert isinstance(node, gast.Constant), (
-            "Type of input node should be gast.Constant, but received %s"
-            % type(node)
-        )
-        # singleton: None, True or False
-        if node.value is None:
-            return {"NONE"}
-        if isinstance(node.value, bool):
-            return {"BOOLEAN"}
-        if isinstance(node.value, int):
-            return {"INT"}
-        if isinstance(node.value, float):
-            return {"FLOAT"}
-        if isinstance(node.value, str):
-            return {"STRING"}
-
-        return {"UNKNOWN"}
-
-    def _get_node_var_type(self, cur_wrapper):
-        node = cur_wrapper.node
-        if isinstance(node, gast.Constant):
-            return self._get_constant_node_type(node)
-
-        if isinstance(node, gast.BoolOp):
-            return {"BOOLEAN"}
-        if isinstance(node, gast.Compare):
-            return {"BOOLEAN"}
-
-        if isinstance(node, gast.Dict):
-            return {"DICT"}
-        if isinstance(node, gast.Set):
-            return {"SET"}
-
-        if isinstance(node, gast.UnaryOp):
-            return self.node_to_wrapper_map[node.operand].node_var_type
-
-        if isinstance(node, gast.BinOp):
-            left_type = self.node_to_wrapper_map[node.left].node_var_type
-            right_type = self.node_to_wrapper_map[node.right].node_var_type
-            result_type = set()
-            for l in left_type:
-                for r in right_type:
-                    result_type.add(binary_op_output_type(l, r))
-            return result_type
-
-        if isinstance(node, gast.Assign):
-            ret_type = self.node_to_wrapper_map[node.value].node_var_type
-            for target in node.targets:
-                if isinstance(target, gast.Name):
-                    self.node_to_wrapper_map[target].node_var_type = ret_type
-                # Handle statements like `a, b = paddle.shape(x)`
-                elif isinstance(target, gast.Tuple):
-                    for sub_target in target.elts:
-                        if isinstance(sub_target, gast.Name):
-                            self.node_to_wrapper_map[
-                                sub_target
-                            ].node_var_type = ret_type
-            return ret_type
-
-        if isinstance(node, gast.AnnAssign):
-            # TODO(0x45f): To determine whether need to support assignment statements
-            # like `self.x: float = 2.1`.
-            ret_type = {type_from_annotation(node.annotation)}
-            # if annotation and value(Constant) are diffent type, we use value type
-            if node.value:
-                node_value_type = self.node_to_wrapper_map[
-                    node.value
-                ].node_var_type
-                if not (node_value_type & {"UNKNOWN", "STATEMENT"}):
-                    ret_type = node_value_type
-            if isinstance(node.target, gast.Name):
-                self.node_to_wrapper_map[node.target].node_var_type = ret_type
-            return ret_type
-
-        if isinstance(node, gast.Name):
-            if node.id == "None":
-                return {"NONE"}
-            if node.id in {"True", "False"}:
-                return {"BOOLEAN"}
-            # If node is child of functionDef.arguments
-            parent_node_wrapper = cur_wrapper.parent
-            if parent_node_wrapper and isinstance(
-                parent_node_wrapper.node, gast.arguments
-            ):
-                return self._get_func_argument_type(parent_node_wrapper, node)
-
-            return {"UNKNOWN"}
-
-        if isinstance(node, gast.Return):
-            # If return nothing:
-            if node.value is None:
-                return {"NONE"}
-
-            return {"UNKNOWN"}
-
-        if isinstance(node, gast.Call):
-            if is_dygraph_api(node):
-                if isinstance(node.func, gast.Attribute):
-                    if node.func.attr == "to_variable":
-                        return {"TENSOR"}
-            if is_paddle_api(node):
-                return {"PADDLE_RETURN_TYPES"}
-            if is_numpy_api(node):
-                # In this simple version we assume numpy api returns nd-array
-                return {"NUMPY_NDARRAY"}
-
-            if isinstance(node.func, gast.Name):
-                return {"UNKNOWN"}
-        if isinstance(node, gast.Subscript):
-            if self.is_tensor_node(node.value):
-                return {"TENSOR"}
-
-        return {"STATEMENT"}
-
-    def _get_func_argument_type(self, parent_node_wrapper, node):
-        """
-        Returns type information by parsing annotation or default values.
-
-        For example:
-            1. parse by default values.
-                foo(x, y=1, z='s') -> x: UNKNOWN, y: INT, z: STR
-
-            2. parse by Py3 type annotation.
-                foo(x: Tensor, y: int, z: str) -> x: Tensor, y: INT, z: STR
-
-            3. parse by type annotation and default values.
-                foo(x: Tensor, y: int, z: str = 'abc') -> x: Tensor, y: INT, z: STR
-
-        NOTE: Currently, we only support Tensor, int, bool, float, str et.al.
-        Other complicate types will be supported later.
-        """
-        assert isinstance(node, gast.Name)
-
-        parent_node = parent_node_wrapper.node
-        var_type = {"UNKNOWN"}
-        if node.annotation is not None:
-            var_type = {type_from_annotation(node.annotation)}
-
-        # if annotation and value(Constant) are diffent type, we use value type
-        if parent_node.defaults:
-            index = index_in_list(parent_node.args, node)
-            args_len = len(parent_node.args)
-            if index != -1 and args_len - index <= len(parent_node.defaults):
-                defaults_node = parent_node.defaults[index - args_len]
-                if isinstance(defaults_node, gast.Constant):
-                    var_type = self._get_constant_node_type(defaults_node)
-
-        return var_type
diff --git a/python/paddle/jit/dy2static/transformers/decorator_transformer.py b/python/paddle/jit/dy2static/transformers/decorator_transformer.py
index 46415926266560..143d1fb1e14d7d 100644
--- a/python/paddle/jit/dy2static/transformers/decorator_transformer.py
+++ b/python/paddle/jit/dy2static/transformers/decorator_transformer.py
@@ -41,8 +41,6 @@ class DecoratorTransformer(BaseTransformer):
     def __init__(self, root):
         self.root = root
 
-        self.ancestor_nodes = []
-
     def transform(self):
         """
         Main function to transform AST.
diff --git a/python/paddle/jit/dy2static/transformers/loop_transformer.py b/python/paddle/jit/dy2static/transformers/loop_transformer.py
index 2d2cfee1f97b0a..52d503b68bab60 100644
--- a/python/paddle/jit/dy2static/transformers/loop_transformer.py
+++ b/python/paddle/jit/dy2static/transformers/loop_transformer.py
@@ -18,7 +18,6 @@
 from paddle.base import unique_name
 from paddle.utils import gast
 
-from ..static_analysis import StaticAnalysisVisitor
 from ..utils import (
     FOR_BODY_PREFIX,
     FOR_CONDITION_PREFIX,
@@ -32,6 +31,7 @@
     create_nonlocal_stmt_nodes,
     create_set_args_node,
     get_attribute_full_name,
+    get_parent_mapping,
 )
 from .base import (
     BaseTransformer,
@@ -137,10 +137,7 @@ def __init__(self, root_node):
         # Some names are types, we shouldn't record them as loop var names.
         self.type_vars = set()
 
-        self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
-        self.node_to_wrapper_map = (
-            self.static_analysis_visitor.get_node_to_wrapper_map()
-        )
+        self.to_parent_mapping = get_parent_mapping(root_node)
 
         self.visit(root_node)
 
@@ -184,10 +181,6 @@ def get_loop_var_names(self, node):
         write_vars = self.write_in_loop[node]
         write_names = self._var_nodes_to_names(write_vars)
 
-        name_to_type = {}
-        for var in in_loop_vars:
-            wrapper = self.node_to_wrapper_map[var]
-            name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type
         for name in in_loop_name_strs:
             if name in before_loop_name_strs:
                 # If a variable is used in loop and created before loop
@@ -363,12 +356,7 @@ def _is_ancestor_node(self, ancestor_node, node):
         return False
 
     def _get_parent_node(self, node):
-        wrapper_node = self.node_to_wrapper_map.get(node)
-        if wrapper_node:
-            if wrapper_node.parent:
-                parent_node = wrapper_node.parent.node
-                return parent_node
-        return None
+        return self.to_parent_mapping.get(node)
 
     def _remove_unnecessary_vars(self, loop_vars, loop_node):
         """
diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py
index 361613d12efeaf..b050f45d65885f 100644
--- a/python/paddle/jit/dy2static/utils.py
+++ b/python/paddle/jit/dy2static/utils.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import atexit
 import builtins
 import copy
@@ -47,7 +49,6 @@
     index_in_list,
     is_api_in_module,
     is_dygraph_api,
-    is_numpy_api,
     is_paddle_api,
 )
 
@@ -119,6 +120,14 @@ def visit(self, node):
         return ret
 
 
+def get_parent_mapping(root):
+    to_parent: dict[gast.AST, gast.AST] = {}
+    for node in gast.walk(root):
+        for child in gast.iter_child_nodes(node):
+            to_parent[child] = node
+    return to_parent
+
+
 dygraph_class_to_static_api = {
     "CosineDecay": "cosine_decay",
     "ExponentialDecay": "exponential_decay",
diff --git a/python/paddle/jit/dy2static/utils_helper.py b/python/paddle/jit/dy2static/utils_helper.py
index 9a55f23cf46db4..cec81940a72a3a 100644
--- a/python/paddle/jit/dy2static/utils_helper.py
+++ b/python/paddle/jit/dy2static/utils_helper.py
@@ -24,7 +24,6 @@
 from paddle.utils import gast
 
 from .ast_utils import ast_to_source_code
-from .logging_utils import warn
 
 
 def index_in_list(array_list, item):
@@ -43,6 +42,7 @@ def index_in_list(array_list, item):
 
 
 def is_dygraph_api(node):
+    # TODO(SigureMo): Cleanup this function after we remove the BasicApiTransformer
     # Note: A api in module dygraph_to_static is not a real dygraph api.
     if is_api_in_module(node, DYGRAPH_TO_STATIC_MODULE_PREFIX):
         return False
@@ -76,76 +76,10 @@ def _is_api_in_module_helper(obj, module_prefix):
     return m is not None and m.__name__.startswith(module_prefix)
 
 
-# Is numpy_api cannot reuse is_api_in_module because of numpy module problem
-def is_numpy_api(node):
-    assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api"
-    func_str = ast_to_source_code(node.func)
-    try:
-        module_result = eval(
-            "_is_api_in_module_helper({}, '{}')".format(func_str, "numpy")
-        )
-        # BUG: np.random.uniform doesn't have module and cannot be analyzed
-        # TODO: find a better way
-        return module_result or (
-            func_str.startswith("numpy.") or func_str.startswith("np.")
-        )
-    except Exception:
-        return False
-
-
 def is_paddle_api(node):
     return is_api_in_module(node, PADDLE_MODULE_PREFIX)
 
 
-def binary_op_output_type(in_type1, in_type2):
-    if in_type1 == in_type2:
-        return in_type1
-
-    if in_type1 == "UNKNOWN":
-        return in_type2
-    if in_type2 == "UNKNOWN":
-        return in_type1
-
-    supported_types = [
-        "BOOLEAN",
-        "INT",
-        "FLOAT",
-        "NUMPY_NDARRAY",
-        "TENSOR",
-        "PADDLE_RETURN_TYPES",
-    ]
-
-    if in_type1 not in supported_types:
-        return "UNKNOWN"
-    if in_type2 not in supported_types:
-        return "UNKNOWN"
-
-    forbidden_types = ["NUMPY_NDARRAY", "TENSOR"]
-    if in_type1 in forbidden_types and in_type2 in forbidden_types:
-        return "UNKNOWN"
-    return max(in_type1, in_type2)
-
-
-Annotation_map = {
-    "Tensor": "TENSOR",
-    "paddle.Tensor": "TENSOR",
-    "int": "INT",
-    "float": "FLOAT",
-    "bool": "BOOLEAN",
-    "str": "STRING",
-}
-
-
-def type_from_annotation(annotation):
-    annotation_str = ast_to_source_code(annotation).strip()
-    if annotation_str in Annotation_map:
-        return Annotation_map[annotation_str]
-
-    # raise warning if not found
-    warn("Currently we don't support annotation: %s" % annotation_str)
-    return "UNKNOWN"
-
-
 def set_dynamic_shape(variable, shape_list):
     if paddle.base.dygraph.base.in_to_static_mode():
         if isinstance(variable, paddle.base.framework.Variable):
diff --git a/python/paddle/jit/dy2static/variable_trans_func.py b/python/paddle/jit/dy2static/variable_trans_func.py
index b32001dd28f7b8..13bec054823158 100644
--- a/python/paddle/jit/dy2static/variable_trans_func.py
+++ b/python/paddle/jit/dy2static/variable_trans_func.py
@@ -26,20 +26,6 @@ def create_undefined_var(name):
     return gast.parse(func_code).body[0]
 
 
-def create_fill_constant_node(name, value=0):
-    func_code = f"{name} = paddle.full(shape=[1], "
-    if isinstance(value, bool):
-        func_code += f"dtype='bool', fill_value={value}, name='{name}')"
-        return gast.parse(func_code).body[0]
-    if isinstance(value, float):
-        func_code += f"dtype='float64', fill_value={value}, name='{name}')"
-        return gast.parse(func_code).body[0]
-
-    if isinstance(value, int):
-        func_code += f"dtype='int64', fill_value={value}, name='{name}')"
-        return gast.parse(func_code).body[0]
-
-
 def to_static_variable(x):
     '''
     Translate a Python Tensor to PaddlePaddle static graph Tensor
diff --git a/test/dygraph_to_static/test_static_analysis.py b/test/dygraph_to_static/test_static_analysis.py
deleted file mode 100644
index 889bf183d079c0..00000000000000
--- a/test/dygraph_to_static/test_static_analysis.py
+++ /dev/null
@@ -1,205 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-import unittest
-
-import numpy as np
-
-import paddle
-from paddle import base
-from paddle.jit.dy2static import StaticAnalysisVisitor
-from paddle.utils import gast
-
-
-def func_to_test1(a, b):
-    return a + b
-
-
-result_var_type1 = {}
-
-
-def func_to_test2(x):
-    for i in range(10):
-        x += i
-    m = 3
-    while m < 8:
-        m += 1
-    if x < 0:
-        return 0
-    else:
-        return x
-
-
-result_var_type2 = {'m': {"INT"}}
-
-
-def func_to_test3():
-    a = 1
-    b = 3.0
-    c = a * b
-    d = True + c
-    e = a < b
-    f = 9 * (a * 4)
-    g = "dddy"
-    h = None
-    i = False
-    j = None + 1
-    k: float = 1.0
-    l: paddle.Tensor = paddle.to_tensor([1, 2])
-
-
-result_var_type3 = {
-    'a': {"INT"},
-    'b': {"FLOAT"},
-    'c': {"FLOAT"},
-    'd': {"FLOAT"},
-    'e': {"BOOLEAN"},
-    'f': {"INT"},
-    'g': {"STRING"},
-    'h': {"NONE"},
-    'i': {"BOOLEAN"},
-    'j': {"UNKNOWN"},
-    'k': {"FLOAT"},
-    'l': {"PADDLE_RETURN_TYPES"},
-}
-
-
-def func_to_test4():
-    with base.dygraph.guard():
-        a = np.random.uniform(0.1, 1, [1, 2])
-        b = 1 + a
-        c = base.dygraph.to_variable(b)
-        d = (c + 1) * 0.3
-
-
-result_var_type4 = {
-    'a': {"NUMPY_NDARRAY"},
-    'b': {"NUMPY_NDARRAY"},
-    'c': {"TENSOR"},
-    'd': {"TENSOR"},
-}
-
-
-def func_to_test5():
-    def inner_int_func():
-        return 1
-
-    def inner_bool_float_func(x):
-        a = 1.0
-        if x > 0:
-            return a
-        return False
-
-    def inner_unknown_func(x):
-        return x
-
-    a = inner_int_func()
-    b = inner_bool_float_func(3)
-    c = inner_unknown_func(None)
-    d = paddle.static.data('x', [1, 2])
-
-
-result_var_type5 = {
-    'a': {"INT"},
-    'b': {"FLOAT", "BOOLEAN"},
-    'c': {"UNKNOWN"},
-    'd': {"PADDLE_RETURN_TYPES"},
-    'inner_int_func': {"INT"},
-    'inner_bool_float_func': {"FLOAT", "BOOLEAN"},
-    'inner_unknown_func': {"UNKNOWN"},
-}
-
-
-def func_to_test6(x, y=1):
-    i = base.dygraph.to_variable(x)
-
-    def add(x, y):
-        return x + y
-
-    while x < 10:
-        i = add(i, x)
-        x = x + y
-
-    return i
-
-
-result_var_type6 = {
-    'i': {"INT"},
-    'x': {"INT"},
-    'y': {"INT"},
-    'add': {"INT"},
-}
-
-
-def func_to_test7(a: int, b: float, c: paddle.Tensor, d: float = 'diff'):
-    a = True
-    e, f = paddle.shape(c)
-    g: paddle.Tensor = len(c)
-
-
-result_var_type7 = {
-    'a': {"BOOLEAN"},
-    'b': {"FLOAT"},
-    'c': {"TENSOR"},
-    'd': {"STRING"},
-    'e': {"PADDLE_RETURN_TYPES"},
-    'f': {"PADDLE_RETURN_TYPES"},
-    'g': {"TENSOR"},
-}
-
-test_funcs = [
-    func_to_test1,
-    func_to_test2,
-    func_to_test3,
-    func_to_test4,
-    func_to_test5,
-    func_to_test6,
-    func_to_test7,
-]
-result_var_type = [
-    result_var_type1,
-    result_var_type2,
-    result_var_type3,
-    result_var_type4,
-    result_var_type5,
-    result_var_type6,
-    result_var_type7,
-]
-
-
-class TestStaticAnalysis(unittest.TestCase):
-    def _check_wrapper(self, wrapper, node_to_wrapper_map):
-        self.assertEqual(node_to_wrapper_map[wrapper.node], wrapper)
-        if wrapper.parent is not None:
-            self.assertTrue(wrapper in wrapper.parent.children)
-
-        children_ast_nodes = list(gast.iter_child_nodes(wrapper.node))
-        self.assertEqual(len(wrapper.children), len(children_ast_nodes))
-        for child in wrapper.children:
-            self.assertTrue(child.node in children_ast_nodes)
-            self._check_wrapper(child, node_to_wrapper_map)
-
-    def test_construct_node_wrapper(self):
-        for func in test_funcs:
-            test_source_code = inspect.getsource(func)
-            ast_root = gast.parse(test_source_code)
-            visitor = StaticAnalysisVisitor(ast_root)
-            wrapper_root = visitor.get_node_wrapper_root()
-            node_to_wrapper_map = visitor.get_node_to_wrapper_map()
-            self._check_wrapper(wrapper_root, node_to_wrapper_map)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/test/dygraph_to_static/test_tensor_hook.py b/test/dygraph_to_static/test_tensor_hook.py
index a2867665b5acb6..e17094e24b3aa7 100644
--- a/test/dygraph_to_static/test_tensor_hook.py
+++ b/test/dygraph_to_static/test_tensor_hook.py
@@ -25,7 +25,7 @@
 from paddle.jit import to_static
 
 
-class TestStaticAnalysis(Dy2StTestBase):
+class TestTensorHook(Dy2StTestBase):
     def test_hook_for_different_parameter(self):
         def f(x):
             def h(g):
diff --git a/test/dygraph_to_static/test_variable_trans_func.py b/test/dygraph_to_static/test_variable_trans_func.py
deleted file mode 100644
index 4cb451cc510238..00000000000000
--- a/test/dygraph_to_static/test_variable_trans_func.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pir
-
-from paddle.jit.dy2static.utils import ast_to_source_code
-from paddle.jit.dy2static.variable_trans_func import create_fill_constant_node
-
-
-class TestVariableTransFunc(Dy2StTestBase):
-    @test_legacy_and_pir
-    def test_create_fill_constant_node(self):
-        node = create_fill_constant_node("a", 1.0)
-        source = "a = paddle.full(shape=[1], dtype='float64', fill_value=1.0, name='a')"
-        self.assertEqual(
-            ast_to_source_code(node).replace('\n', '').replace(' ', ''),
-            source.replace(' ', ''),
-        )
-
-        node = create_fill_constant_node("b", True)
-        source = "b = paddle.full(shape=[1], dtype='bool', fill_value=True, name='b')"
-        self.assertEqual(
-            ast_to_source_code(node).replace('\n', '').replace(' ', ''),
-            source.replace(' ', ''),
-        )
-
-        node = create_fill_constant_node("c", 4293)
-        source = "c = paddle.full(shape=[1], dtype='int64', fill_value=4293, name='c')"
-        self.assertEqual(
-            ast_to_source_code(node).replace('\n', '').replace(' ', ''),
-            source.replace(' ', ''),
-        )
-
-        self.assertIsNone(create_fill_constant_node("e", None))
-        self.assertIsNone(create_fill_constant_node("e", []))
-
-
-if __name__ == '__main__':
-    unittest.main()
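
For reference, below is a minimal usage sketch (not part of the patch) of the get_parent_mapping helper that this diff adds to python/paddle/jit/dy2static/utils.py and imports into loop_transformer.py in place of the StaticAnalysisVisitor wrapper-based _get_parent_node lookup. The standalone copy of the helper and the sample source string are illustrative only; the sketch assumes paddle.utils.gast is importable, as in the patched modules.

from paddle.utils import gast


def get_parent_mapping(root):
    # Same one-pass walk as the helper introduced in utils.py:
    # record each direct child's parent node in a plain dict.
    to_parent = {}
    for node in gast.walk(root):
        for child in gast.iter_child_nodes(node):
            to_parent[child] = node
    return to_parent


root = gast.parse("def foo(x):\n    return x + 1\n")
mapping = get_parent_mapping(root)

func_def = root.body[0]          # gast.FunctionDef for `foo`
return_stmt = func_def.body[0]   # gast.Return for `return x + 1`

# Equivalent to the old wrapper-based parent lookup in loop_transformer.py,
# but with no AstNodeWrapper objects involved.
assert mapping.get(return_stmt) is func_def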