From f2ab245408c7bce1c962af56e0cc5c112d9525a3 Mon Sep 17 00:00:00 2001 From: feifei-111 Date: Tue, 13 Sep 2022 19:10:56 +0800 Subject: [PATCH 1/2] [dy2static] support user to use decorator in their program (#45768) * support deco * fix deco ast type * arg_str * 1 * support callable deco * code style * codestyle * test_error * fix decos in another file * recover conflict codes --- .../dygraph_to_static/ast_transformer.py | 27 +-- .../dygraph_to_static/convert_call_func.py | 5 +- .../decorator_transformer.py | 118 ++++++++++++++ .../dygraph_to_static/return_transformer.py | 14 ++ .../unittests/dygraph_to_static/decos.py | 46 ++++++ .../test_decorator_transform.py | 154 ++++++++++++++++++ .../unittests/dygraph_to_static/test_error.py | 10 -- 7 files changed, 338 insertions(+), 36 deletions(-) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/decos.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index fd146d77632ca1..b936c47b511358 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -18,6 +18,7 @@ # It provides a compatibility layer between the AST of various Python versions, # as produced by ast.parse from the standard ast module. # See details in https://github.com/serge-sans-paille/gast/ + import os from paddle.utils import gast from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer @@ -38,6 +39,7 @@ from paddle.fluid.dygraph.dygraph_to_static.create_variable_transformer import CreateVariableTransformer from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer +from paddle.fluid.dygraph.dygraph_to_static.decorator_transformer import DecoratorTransformer from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code @@ -105,6 +107,7 @@ def transfer_from_node_type(self, node_wrapper): CallTransformer, # transform call recursively CastTransformer, # type casting statement GradTransformer, # transform paddle.grad to paddle.gradients + DecoratorTransformer, # transform decorators to function call ] apply_optimization(transformers) @@ -120,30 +123,6 @@ def visit_FunctionDef(self, node): self.decorate_func_name = node.name self.generic_visit(node) - # Remove the decorated name of dygraph_to_static - if hasattr(node, 'decorator_list'): - decorator_list = [] - ignore_list = ["staticmethod"] - for d in node.decorator_list: - if isinstance(d, gast.Name) and d.id in ignore_list: - continue - if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES: - raise NotImplementedError( - "ProgramTranslator hasn't implemented multiple decorators. Please remove " - + d.id + " in " + self.decorate_func_name) - if isinstance(d, gast.Attribute): - full_attribute_name = get_attribute_full_name(d) - has_translate_decorator = False - for deco in DECORATOR_NAMES: - if deco in full_attribute_name: - has_translate_decorator = True - break - if not has_translate_decorator: - raise NotImplementedError( - "ProgramTranslator hasn't implemented multiple decorators. Please remove " - + full_attribute_name + " in " + - self.decorate_func_name) - node.decorator_list = decorator_list return node def get_module_name(self): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index fda668dc7455f2..a3d96b6fe0ad86 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -33,7 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators -from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func +from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func, unwrap from paddle.fluid.dygraph.layers import Layer __all__ = ["convert_call"] @@ -206,8 +206,9 @@ def dyfunc(x): # `foo` will be converted into a wrapper class, suppose as `StaticFunction`. # And `foo.__globals__['foo']` will still return this `StaticFunction` instead of # `foo` function. So `isinstance(fn, StaticFunction)` is added here. + _origfunc = unwrap(func) global_functions = set() - for fn in func.__globals__.values(): + for fn in _origfunc.__globals__.values(): if inspect.isfunction(fn): global_functions.add(fn) elif isinstance(fn, StaticFunction): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py new file mode 100644 index 00000000000000..ab193b674c25c4 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py @@ -0,0 +1,118 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from paddle.utils import gast +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper +from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer +from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code + +import re + +IGNORE_NAMES = [ + 'declarative', 'to_static', 'dygraph_to_static_func', 'wraps', + 'staticmethod', 'classmethod' +] + + +class DecoratorTransformer(BaseTransformer): + """ + Transform decorators. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Type of input node should be AstNodeWrapper, but received %s ." % type( + wrapper_root) + self.root = wrapper_root.node + + self.ancestor_nodes = [] + + def transform(self): + """ + Main function to transform AST. + """ + self.visit(self.root) + + def visit_FunctionDef(self, node): + assert isinstance(node, gast.FunctionDef) + self.generic_visit(node) + + deco_list = node.decorator_list + node.decorator_list = [] + + # every decorator will append a node + decofun_nodes = [] + # func to be decoed next time + deco_target = '_orig_' + node.name + # last decoed func + decoed_func = '' + + for deco in reversed(deco_list): + # skip INGNORE_NAMES + if isinstance(deco, gast.Attribute): + deco_name = deco.attr + elif isinstance(deco, gast.Call): + if hasattr(deco.func, 'args'): + deco_name = deco.func.args[0].id + elif hasattr(deco.func, 'attr'): + deco_name = deco.func.attr + else: + deco_name = deco.func.id + else: + deco_name = deco.id + if deco_name in IGNORE_NAMES: + continue + + # get function after decoration + deco_full_name = ast_to_source_code(deco).strip() + decoed_func = '_decoby_' + deco_name + if isinstance(deco, gast.Call): + # in this case , the deco_full_name will be like: + # '_jst.Call(deco)(5)' + rematch = re.match(r'\_jst\.Call\((.+?)\)\((.+?)\)', + deco_full_name) + re_name = rematch.group(1) + re_args = rematch.group(2) + re_args_with_func = deco_target + ', ' + re_args + decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\ + .format(decoed_func, re_name, re_args_with_func, re_args, deco_target) + else: + decofun_str = '{} = _jst.Call({})({})'.format( + decoed_func, deco_full_name, deco_target) + + decofun_nodes.extend(gast.parse(decofun_str).body) + deco_target = decoed_func + + if not decofun_nodes: + return node + + orig_func_node = gast.FunctionDef(name='_orig_' + node.name, + args=node.args, + body=node.body, + decorator_list=[], + returns=None, + type_comment=None) + + args = [arg.id for arg in node.args.args] + arg_str = ','.join(args) + callfun_str = 'return {}({})'.format(decoed_func, arg_str) + callfun_node = gast.parse(callfun_str).body[0] + + node.body = [orig_func_node] + decofun_nodes + [callfun_node] + + return node diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py index 3eadd455e1033e..ed2a739936e1e3 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py @@ -228,6 +228,20 @@ def visit_FunctionDef(self, node): # Prepend no value placeholders self.function_def.pop() + + # Need update self.pre_analysis after pop + # For fix this case: + ''' + def fun(cond): + def inner(): + pass + if cond: + return True + else: + return False + ''' + if self.function_def: + self.pre_analysis = ReturnAnalysisVisitor(self.function_def[-1]) return node def visit_Return(self, node): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/decos.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/decos.py new file mode 100644 index 00000000000000..6e3333c15a0ce1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/decos.py @@ -0,0 +1,46 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy +import paddle + +from functools import wraps + + +def deco1(fun): + + @wraps(fun) + def inner(*args, **kwargs): + print('in decos.deco1, added 1') + _t = paddle.to_tensor([1]) + _tt = fun(*args, **kwargs) + return paddle.add(_t, _tt) + + return inner + + +def deco2(x=0): + + def inner_deco(func): + + @wraps(func) + def inner(*args, **kwargs): + print('in decos.deco2, added {}'.format(x)) + _t = paddle.to_tensor(x) + _tt = func(*args, **kwargs) + return paddle.add(_t, _tt) + + return inner + + return inner_deco diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py new file mode 100644 index 00000000000000..c6c2750e307939 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py @@ -0,0 +1,154 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import unittest +import numpy as np +import decos +from functools import wraps + + +def deco1(func): + + @wraps(func) + def inner(*args, **kwargs): + print('in deco1, added 1') + _x = 2 + if (_x < 1): + _x += 1 + else: + _x -= 1 + _t = paddle.to_tensor([1]) + _tt = func(*args, **kwargs) + return paddle.add(_t, _tt) + + return inner + + +def deco2(fun): + + @wraps(fun) + def inner(*args, **kwargs): + print('in deco2, added 2') + _t = paddle.to_tensor([2]) + _tt = fun(*args, **kwargs) + return paddle.add(_t, _tt) + + return inner + + +def deco3(x=3): + + def inner_deco(func): + + @wraps(func) + def inner(*args, **kwargs): + print('in deco3, added {}'.format(x)) + _t = paddle.to_tensor(x) + _tt = func(*args, **kwargs) + return paddle.add(_t, _tt) + + return inner + + return inner_deco + + +def deco4(func=None, x=0): + + def decorated(pyfunc): + + @wraps(pyfunc) + def inner_deco(*args, **kwargs): + print('in deco4, added {}'.format(x)) + _t = paddle.to_tensor(x) + _tt = pyfunc(*args, **kwargs) + return paddle.add(_t, _tt) + + return inner_deco + + if func == None: + return decorated + return decorated(func) + + +@deco2 +def fun1(x, y=0): + a = paddle.to_tensor(y) + print('in fun1, x=%d' % (x)) + return a + + +@deco1 +@deco2 +def fun2(x, y=0): + a = paddle.to_tensor(y) + print('in fun2, x=%d' % (x)) + return a + + +@deco3(3) +def fun3(x, y=0): + a = paddle.to_tensor(y) + print('in fun3, x=%d' % (x)) + return a + + +@deco4(x=4) +def fun4(x, y=0): + a = paddle.to_tensor(y) + print('in fun4, x=%d' % (x)) + return a + + +@deco2 +@deco4(x=5) +def fun5(x, y=0): + a = paddle.to_tensor(y) + print('in fun5, x=%d' % (x)) + return a + + +@decos.deco1 +@decos.deco2(2) +def fun6(x, y=0): + a = paddle.to_tensor(y) + print('in fun6, x=%d' % (x)) + return a + + +@paddle.jit.to_static +def forward(): + funcs = [fun1, fun2, fun3, fun4, fun5, fun6] + out = [] + for idx, fun in enumerate(funcs): + out.append(fun(idx + 1, idx + 1)) + return out + + +class TestDecoratorTransform(unittest.TestCase): + + def test_deco_transform(self): + outs = forward() + np.testing.assert_allclose(outs[0], np.array(3), rtol=1e-05) + np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05) + np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05) + np.testing.assert_allclose(outs[3], np.array(8), rtol=1e-05) + np.testing.assert_allclose(outs[4], np.array(12), rtol=1e-05) + np.testing.assert_allclose(outs[5], np.array(9), rtol=1e-05) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py index 27d7389b903cc4..97f0cf99b5f65d 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py @@ -399,16 +399,6 @@ def test_error(self): # # Situation 4: NotImplementedError -class TestErrorInOther(unittest.TestCase): - def test(self): - paddle.disable_static() - prog_trans = paddle.jit.ProgramTranslator() - with self.assertRaises(NotImplementedError): - prog_trans.get_output(func_decorated_by_other_1) - - with self.assertRaises(NotImplementedError): - func_decorated_by_other_2() - class TestSuggestionErrorInRuntime(TestErrorBase): def set_func(self): From 5ad5deb8e63971f1580be1fd19d0f6ebf371c201 Mon Sep 17 00:00:00 2001 From: feifei-111 Date: Sat, 17 Sep 2022 11:41:53 +0800 Subject: [PATCH 2/2] [BugFix] fixed a bug in decorator transformer, it can not analyze decorator with params correctly (#46055) * fix deco call * add raise * add test * add warn, fix paddle api * fix error type * fix coverage --- .../dygraph_to_static/ast_transformer.py | 2 - .../decorator_transformer.py | 42 +++++++---- .../test_decorator_transform.py | 75 ++++++++++++++++++- 3 files changed, 101 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index b936c47b511358..e045348e6c942a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -47,8 +47,6 @@ __all__ = ['DygraphToStaticAst'] -DECORATOR_NAMES = ['declarative', 'to_static', 'dygraph_to_static_func'] - def apply_optimization(transformers): """ diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py index ab193b674c25c4..8442403e04c83e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py @@ -18,13 +18,14 @@ from paddle.utils import gast from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer -from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code +from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code, is_paddle_api, Dygraph2StaticException +import warnings import re IGNORE_NAMES = [ 'declarative', 'to_static', 'dygraph_to_static_func', 'wraps', - 'staticmethod', 'classmethod' + 'staticmethod', 'classmethod', 'decorator' ] @@ -77,20 +78,35 @@ def visit_FunctionDef(self, node): deco_name = deco.id if deco_name in IGNORE_NAMES: continue + elif deco_name == 'contextmanager': + warnings.warn( + "Dy2Static : A context manager decorator is used, this may not work correctly after transform." + ) - # get function after decoration deco_full_name = ast_to_source_code(deco).strip() - decoed_func = '_decoby_' + deco_name + decoed_func = '_decoedby_' + deco_name + + # get function after decoration if isinstance(deco, gast.Call): - # in this case , the deco_full_name will be like: - # '_jst.Call(deco)(5)' - rematch = re.match(r'\_jst\.Call\((.+?)\)\((.+?)\)', - deco_full_name) - re_name = rematch.group(1) - re_args = rematch.group(2) - re_args_with_func = deco_target + ', ' + re_args - decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\ - .format(decoed_func, re_name, re_args_with_func, re_args, deco_target) + if '_jst.Call' in deco_full_name: + # in this case , the deco_full_name will be like: + # '_jst.Call(deco)(5)' + rematch = re.match(r'\_jst\.Call\((.+?)\)\((.*)\)', + deco_full_name) + re_name = rematch.group(1) + re_args = rematch.group(2) + re_args_with_func = deco_target + ', ' + re_args + decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\ + .format(decoed_func, re_name, re_args_with_func, re_args, deco_target) + else: + # paddle api will not be transformed to '_jst.Call' + rematch = re.match(r'(.+?)\((.*)\)', deco_full_name) + re_name = rematch.group(1) + re_args = rematch.group(2) + re_args_with_func = deco_target + ', ' + re_args + decofun_str = 'try:\n\t{0} = {1}({2})\nexcept:\n\t{0} = {1}({3})({4})'\ + .format(decoed_func, re_name, re_args_with_func, re_args, deco_target) + else: decofun_str = '{} = _jst.Call({})({})'.format( decoed_func, deco_full_name, deco_target) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py index c6c2750e307939..4acc789a451bb0 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py @@ -18,7 +18,9 @@ import unittest import numpy as np import decos +import warnings from functools import wraps +from contextlib import contextmanager def deco1(func): @@ -84,6 +86,14 @@ def inner_deco(*args, **kwargs): return decorated(func) +def deco5(): + return deco2 + + +def deco6(x=0): + return deco2 + + @deco2 def fun1(x, y=0): a = paddle.to_tensor(y) @@ -114,7 +124,7 @@ def fun4(x, y=0): @deco2 -@deco4(x=5) +@deco4() def fun5(x, y=0): a = paddle.to_tensor(y) print('in fun5, x=%d' % (x)) @@ -129,15 +139,55 @@ def fun6(x, y=0): return a +@deco5() +def fun7(x, y=0): + a = paddle.to_tensor(y) + print('in fun7, x=%d' % (x)) + return a + + +@deco6(2) +def fun8(x, y=0): + a = paddle.to_tensor(y) + print('in fun8, x=%d' % (x)) + return a + + @paddle.jit.to_static def forward(): - funcs = [fun1, fun2, fun3, fun4, fun5, fun6] + funcs = [fun1, fun2, fun3, fun4, fun5, fun6, fun7, fun8] out = [] for idx, fun in enumerate(funcs): out.append(fun(idx + 1, idx + 1)) return out +@contextmanager +def contextmanager_warning(): + yield + + +@contextmanager_warning() +def fun9(): + print('in fun9 want contextmanager warning') + + +@paddle.jit.to_static +def warn1(): + fun9() + + +@paddle.no_grad() +def fun10(): + print('in fun10, paddle api decorated') + return True + + +@paddle.jit.to_static +def deco_with_paddle_api(): + return fun10() + + class TestDecoratorTransform(unittest.TestCase): def test_deco_transform(self): @@ -146,8 +196,27 @@ def test_deco_transform(self): np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05) np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05) np.testing.assert_allclose(outs[3], np.array(8), rtol=1e-05) - np.testing.assert_allclose(outs[4], np.array(12), rtol=1e-05) + np.testing.assert_allclose(outs[4], np.array(7), rtol=1e-05) np.testing.assert_allclose(outs[5], np.array(9), rtol=1e-05) + np.testing.assert_allclose(outs[6], np.array(9), rtol=1e-05) + np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05) + + def test_contextmanager_warning(self): + paddle.disable_static() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + warn1() + flag = False + for warn in w: + if (issubclass(warn.category, UserWarning) + ) and "A context manager decorator is used" in str( + warn.message): + flag = True + break + self.assertTrue(flag) + + def test_deco_with_paddle_api(self): + self.assertTrue(deco_with_paddle_api()) if __name__ == '__main__':