Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
Expand Down Expand Up @@ -84,6 +85,9 @@ def __init__(self, root_node):
self.condition_vars = defaultdict(set)
self.in_condition = False

# Some names are types, we shouldn't record them as loop var names.
self.type_vars = set()

self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
Expand Down Expand Up @@ -249,6 +253,18 @@ def visit_While(self, node):
self.generic_visit(node)
self.current_loop.pop()

    def visit_Call(self, node):
        """Record names used as the type argument of ``isinstance`` calls.

        Loop analysis should not treat type names as loop variables, so any
        name appearing as the second argument of ``isinstance(x, T)`` (or any
        element of a tuple second argument) is stored in ``self.type_vars``
        and filtered out later in ``_remove_unnecessary_vars``.

        NOTE: only calls spelled literally as the bare name ``isinstance``
        are detected (``node.func`` must be a ``gast.Name``); an aliased or
        attribute-qualified call would be missed.
        """
        # Store type var names such as "isinstance(x, some_type_names)" and
        # remove them later.
        if isinstance(node.func, gast.Name) and node.func.id == 'isinstance':
            type_node = node.args[1]
            if isinstance(type_node, gast.Tuple):
                # isinstance(x, (TypeA, TypeB, ...)): record every element.
                for element in type_node.elts:
                    self.type_vars.add(ast_to_source_code(element))
            else:
                self.type_vars.add(ast_to_source_code(type_node))
        self.generic_visit(node)

def _var_nodes_to_names(self, node_set, ctx_filter_set=None):
ret = set()
for node in node_set:
Expand Down Expand Up @@ -290,6 +306,7 @@ def _remove_unnecessary_vars(self, loop_vars, loop_node):
Remove unnecessary vars from before_loop_vars, after_loop_vars or in_loop_vars about loop_node.
1. Remove target vars of gast.For from before_loop_vars or after_loop_vars.
2. Remove vars only in gast.comprehension.
3. Remove vars that are type names, for example: "isinstance(x, var_type_name)"
:param loop_vars: before_loop_vars, after_loop_vars or in_loop_vars of loop_node.
:param loop_node: Current loop node.
"""
Expand Down Expand Up @@ -361,6 +378,12 @@ def _remove_unnecessary_vars(self, loop_vars, loop_node):
target_vars_of_for_node.add(var)

removed_vars = target_vars_of_for_node | vars_of_list_generator

# 3. Remove var type names which are stored in self.type_vars
for var in loop_vars:
if ast_to_source_code(var) in self.type_vars:
removed_vars.add(var)

return loop_vars - removed_vars


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import numpy as np
import unittest

import paddle
import paddle.nn as nn


class SimpleReturnLayer(nn.Layer):
    """Identity layer: forwards its input unchanged."""

    def forward(self, x):
        """Return *x* as-is."""
        return x


class AddAttrLayer(nn.Layer):
    """Layer that adds a stored attribute to its input.

    ``attr`` starts as ``None`` and is assigned externally (by the wrapping
    layer) before ``forward`` is called.
    """

    def __init__(self):
        super(AddAttrLayer, self).__init__()
        # Filled in by the caller before forward() runs.
        self.attr = None

    def forward(self, x):
        """Return ``x + self.attr``."""
        return x + self.attr


class IsInstanceLayer(nn.Layer):
    """Wraps one sub-layer and conditionally configures it via ``isinstance``.

    Exercises dygraph-to-static handling of an ``isinstance`` call whose
    second argument is a tuple of types.
    """

    def __init__(self, layer):
        super(IsInstanceLayer, self).__init__()
        self.layer = layer

    @paddle.jit.to_static
    def forward(self, x):
        # The single-element tuple is deliberate: it exercises the tuple
        # branch of the isinstance handling in the static transform.
        if isinstance(self.layer, (AddAttrLayer, )):
            self.layer.attr = x
        return self.layer(x)


class SequentialLayer(nn.Layer):
    """Applies a list of sub-layers in order.

    Exercises dygraph-to-static handling of an ``isinstance`` call inside a
    ``for`` loop, with a single (non-tuple) type argument; AddAttrLayer
    instances get their ``attr`` set before being applied.
    """

    def __init__(self, layers):
        super(SequentialLayer, self).__init__()
        self.layers = nn.LayerList(layers)

    @paddle.jit.to_static
    def forward(self, x):
        res = x
        for sub_layer in self.layers:
            # Plain (non-tuple) second argument: exercises the single-name
            # branch of the isinstance handling in the static transform.
            if isinstance(sub_layer, AddAttrLayer):
                sub_layer.attr = x
            res = sub_layer(res)
        return res


def train(model, to_static):
    """Run ``model`` once on a fixed input and return the output as numpy.

    :param model: the Layer under test.
    :param to_static: whether to enable the ProgramTranslator (static graph).
    """
    translator = paddle.jit.ProgramTranslator.get_instance()
    translator.enable(to_static)

    data = paddle.ones(shape=[2, 3], dtype='int32')
    return model(data).numpy()


class TestIsinstance(unittest.TestCase):
    """Check dygraph and static-graph outputs agree for isinstance models."""

    def test_isinstance_simple_return_layer(self):
        self._test_model(IsInstanceLayer(SimpleReturnLayer()))

    def test_isinstance_add_attr_layer(self):
        self._test_model(IsInstanceLayer(AddAttrLayer()))

    def test_sequential_layer(self):
        # Five alternating pairs: identity layer followed by attr-adding one.
        layers = [
            cls()
            for _ in range(5)
            for cls in (SimpleReturnLayer, AddAttrLayer)
        ]
        self._test_model(SequentialLayer(layers))

    def _test_model(self, model):
        # Run the same model twice and compare static vs. dygraph results.
        st_out = train(model, to_static=True)
        dy_out = train(model, to_static=False)
        self.assertTrue(
            np.allclose(dy_out, st_out),
            msg="dy_out:\n {}\n st_out:\n{}".format(dy_out, st_out))


# Run the test suite when executed directly as a script.
if __name__ == "__main__":
    unittest.main()