Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@

__all__ = ['DygraphToStaticAst']

DECORATOR_NAMES = ['declarative', 'to_static', 'dygraph_to_static_func']


def apply_optimization(transformers):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code, is_paddle_api, Dygraph2StaticException
import warnings

import re

IGNORE_NAMES = [
'declarative', 'to_static', 'dygraph_to_static_func', 'wraps',
'staticmethod', 'classmethod'
'staticmethod', 'classmethod', 'decorator'
]


Expand Down Expand Up @@ -77,20 +78,35 @@ def visit_FunctionDef(self, node):
deco_name = deco.id
if deco_name in IGNORE_NAMES:
continue
elif deco_name == 'contextmanager':
warnings.warn(
"Dy2Static : A context manager decorator is used, this may not work correctly after transform."
)

# get function after decoration
deco_full_name = ast_to_source_code(deco).strip()
decoed_func = '_decoby_' + deco_name
decoed_func = '_decoedby_' + deco_name

# get function after decoration
if isinstance(deco, gast.Call):
# in this case , the deco_full_name will be like:
# '_jst.Call(deco)(5)'
rematch = re.match(r'\_jst\.Call\((.+?)\)\((.+?)\)',
deco_full_name)
re_name = rematch.group(1)
re_args = rematch.group(2)
re_args_with_func = deco_target + ', ' + re_args
decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\
.format(decoed_func, re_name, re_args_with_func, re_args, deco_target)
if '_jst.Call' in deco_full_name:
# in this case , the deco_full_name will be like:
# '_jst.Call(deco)(5)'
rematch = re.match(r'\_jst\.Call\((.+?)\)\((.*)\)',
deco_full_name)
re_name = rematch.group(1)
re_args = rematch.group(2)
re_args_with_func = deco_target + ', ' + re_args
decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\
.format(decoed_func, re_name, re_args_with_func, re_args, deco_target)
else:
# paddle api will not be transformed to '_jst.Call'
rematch = re.match(r'(.+?)\((.*)\)', deco_full_name)
re_name = rematch.group(1)
re_args = rematch.group(2)
re_args_with_func = deco_target + ', ' + re_args
decofun_str = 'try:\n\t{0} = {1}({2})\nexcept:\n\t{0} = {1}({3})({4})'\
.format(decoed_func, re_name, re_args_with_func, re_args, deco_target)

else:
decofun_str = '{} = _jst.Call({})({})'.format(
decoed_func, deco_full_name, deco_target)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import unittest
import numpy as np
import decos
import warnings
from functools import wraps
from contextlib import contextmanager


def deco1(func):
Expand Down Expand Up @@ -84,6 +86,14 @@ def inner_deco(*args, **kwargs):
return decorated(func)


def deco5():
return deco2


def deco6(x=0):
return deco2


@deco2
def fun1(x, y=0):
a = paddle.to_tensor(y)
Expand Down Expand Up @@ -114,7 +124,7 @@ def fun4(x, y=0):


@deco2
@deco4(x=5)
@deco4()
def fun5(x, y=0):
a = paddle.to_tensor(y)
print('in fun5, x=%d' % (x))
Expand All @@ -129,15 +139,55 @@ def fun6(x, y=0):
return a


@deco5()
def fun7(x, y=0):
a = paddle.to_tensor(y)
print('in fun7, x=%d' % (x))
return a


@deco6(2)
def fun8(x, y=0):
a = paddle.to_tensor(y)
print('in fun8, x=%d' % (x))
return a


@paddle.jit.to_static
def forward():
funcs = [fun1, fun2, fun3, fun4, fun5, fun6]
funcs = [fun1, fun2, fun3, fun4, fun5, fun6, fun7, fun8]
out = []
for idx, fun in enumerate(funcs):
out.append(fun(idx + 1, idx + 1))
return out


@contextmanager
def contextmanager_warning():
yield


@contextmanager_warning()
def fun9():
print('in fun9 want contextmanager warning')


@paddle.jit.to_static
def warn1():
fun9()


@paddle.no_grad()
def fun10():
print('in fun10, paddle api decorated')
return True


@paddle.jit.to_static
def deco_with_paddle_api():
return fun10()


class TestDecoratorTransform(unittest.TestCase):

def test_deco_transform(self):
Expand All @@ -146,8 +196,27 @@ def test_deco_transform(self):
np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05)
np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05)
np.testing.assert_allclose(outs[3], np.array(8), rtol=1e-05)
np.testing.assert_allclose(outs[4], np.array(12), rtol=1e-05)
np.testing.assert_allclose(outs[4], np.array(7), rtol=1e-05)
np.testing.assert_allclose(outs[5], np.array(9), rtol=1e-05)
np.testing.assert_allclose(outs[6], np.array(9), rtol=1e-05)
np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05)

def test_contextmanager_warning(self):
paddle.disable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
warn1()
flag = False
for warn in w:
if (issubclass(warn.category, UserWarning)
) and "A context manager decorator is used" in str(
warn.message):
flag = True
break
self.assertTrue(flag)

def test_deco_with_paddle_api(self):
self.assertTrue(deco_with_paddle_api())


if __name__ == '__main__':
Expand Down