Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 3 additions & 26 deletions python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/

import os
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
Expand All @@ -38,15 +39,14 @@
from paddle.fluid.dygraph.dygraph_to_static.create_variable_transformer import CreateVariableTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
from paddle.fluid.dygraph.dygraph_to_static.decorator_transformer import DecoratorTransformer

from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name

__all__ = ['DygraphToStaticAst']

DECORATOR_NAMES = ['declarative', 'to_static', 'dygraph_to_static_func']


def apply_optimization(transformers):
"""
Expand Down Expand Up @@ -105,6 +105,7 @@ def transfer_from_node_type(self, node_wrapper):
CallTransformer, # transform call recursively
CastTransformer, # type casting statement
GradTransformer, # transform paddle.grad to paddle.gradients
DecoratorTransformer, # transform decorators to function call
]

apply_optimization(transformers)
Expand All @@ -120,30 +121,6 @@ def visit_FunctionDef(self, node):
self.decorate_func_name = node.name

self.generic_visit(node)
# Remove the decorated name of dygraph_to_static
if hasattr(node, 'decorator_list'):
decorator_list = []
ignore_list = ["staticmethod"]
for d in node.decorator_list:
if isinstance(d, gast.Name) and d.id in ignore_list:
continue
if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
raise NotImplementedError(
"ProgramTranslator hasn't implemented multiple decorators. Please remove "
+ d.id + " in " + self.decorate_func_name)
if isinstance(d, gast.Attribute):
full_attribute_name = get_attribute_full_name(d)
has_translate_decorator = False
for deco in DECORATOR_NAMES:
if deco in full_attribute_name:
has_translate_decorator = True
break
if not has_translate_decorator:
raise NotImplementedError(
"ProgramTranslator hasn't implemented multiple decorators. Please remove "
+ full_attribute_name + " in " +
self.decorate_func_name)
node.decorator_list = decorator_list
return node

def get_module_name(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func, unwrap
from paddle.fluid.dygraph.layers import Layer

__all__ = ["convert_call"]
Expand Down Expand Up @@ -206,8 +206,9 @@ def dyfunc(x):
# `foo` will be converted into a wrapper class, suppose as `StaticFunction`.
# And `foo.__globals__['foo']` will still return this `StaticFunction` instead of
# `foo` function. So `isinstance(fn, StaticFunction)` is added here.
_origfunc = unwrap(func)
global_functions = set()
for fn in func.__globals__.values():
for fn in _origfunc.__globals__.values():
if inspect.isfunction(fn):
global_functions.add(fn)
elif isinstance(fn, StaticFunction):
Expand Down
134 changes: 134 additions & 0 deletions python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code, is_paddle_api, Dygraph2StaticException
import warnings

import re

IGNORE_NAMES = [
'declarative', 'to_static', 'dygraph_to_static_func', 'wraps',
'staticmethod', 'classmethod', 'decorator'
]


class DecoratorTransformer(BaseTransformer):
"""
Transform decorators.
"""

def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Type of input node should be AstNodeWrapper, but received %s ." % type(
wrapper_root)
self.root = wrapper_root.node

self.ancestor_nodes = []

def transform(self):
"""
Main function to transform AST.
"""
self.visit(self.root)

def visit_FunctionDef(self, node):
assert isinstance(node, gast.FunctionDef)
self.generic_visit(node)

deco_list = node.decorator_list
node.decorator_list = []

# every decorator will append a node
decofun_nodes = []
# func to be decoed next time
deco_target = '_orig_' + node.name
# last decoed func
decoed_func = ''

for deco in reversed(deco_list):
# skip INGNORE_NAMES
if isinstance(deco, gast.Attribute):
deco_name = deco.attr
elif isinstance(deco, gast.Call):
if hasattr(deco.func, 'args'):
deco_name = deco.func.args[0].id
elif hasattr(deco.func, 'attr'):
deco_name = deco.func.attr
else:
deco_name = deco.func.id
else:
deco_name = deco.id
if deco_name in IGNORE_NAMES:
continue
elif deco_name == 'contextmanager':
warnings.warn(
"Dy2Static : A context manager decorator is used, this may not work correctly after transform."
)

deco_full_name = ast_to_source_code(deco).strip()
decoed_func = '_decoedby_' + deco_name

# get function after decoration
if isinstance(deco, gast.Call):
if '_jst.Call' in deco_full_name:
# in this case , the deco_full_name will be like:
# '_jst.Call(deco)(5)'
rematch = re.match(r'\_jst\.Call\((.+?)\)\((.*)\)',
deco_full_name)
re_name = rematch.group(1)
re_args = rematch.group(2)
re_args_with_func = deco_target + ', ' + re_args
decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\
.format(decoed_func, re_name, re_args_with_func, re_args, deco_target)
else:
# paddle api will not be transformed to '_jst.Call'
rematch = re.match(r'(.+?)\((.*)\)', deco_full_name)
re_name = rematch.group(1)
re_args = rematch.group(2)
re_args_with_func = deco_target + ', ' + re_args
decofun_str = 'try:\n\t{0} = {1}({2})\nexcept:\n\t{0} = {1}({3})({4})'\
.format(decoed_func, re_name, re_args_with_func, re_args, deco_target)

else:
decofun_str = '{} = _jst.Call({})({})'.format(
decoed_func, deco_full_name, deco_target)

decofun_nodes.extend(gast.parse(decofun_str).body)
deco_target = decoed_func

if not decofun_nodes:
return node

orig_func_node = gast.FunctionDef(name='_orig_' + node.name,
args=node.args,
body=node.body,
decorator_list=[],
returns=None,
type_comment=None)

args = [arg.id for arg in node.args.args]
arg_str = ','.join(args)
callfun_str = 'return {}({})'.format(decoed_func, arg_str)
callfun_node = gast.parse(callfun_str).body[0]

node.body = [orig_func_node] + decofun_nodes + [callfun_node]

return node
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,20 @@ def visit_FunctionDef(self, node):

# Prepend no value placeholders
self.function_def.pop()

# Need update self.pre_analysis after pop
# For fix this case:
'''
def fun(cond):
def inner():
pass
if cond:
return True
else:
return False
'''
if self.function_def:
self.pre_analysis = ReturnAnalysisVisitor(self.function_def[-1])
return node

def visit_Return(self, node):
Expand Down
46 changes: 46 additions & 0 deletions python/paddle/fluid/tests/unittests/dygraph_to_static/decos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy
import paddle

from functools import wraps


def deco1(fun):

@wraps(fun)
def inner(*args, **kwargs):
print('in decos.deco1, added 1')
_t = paddle.to_tensor([1])
_tt = fun(*args, **kwargs)
return paddle.add(_t, _tt)

return inner


def deco2(x=0):

def inner_deco(func):

@wraps(func)
def inner(*args, **kwargs):
print('in decos.deco2, added {}'.format(x))
_t = paddle.to_tensor(x)
_tt = func(*args, **kwargs)
return paddle.add(_t, _tt)

return inner

return inner_deco
Loading