Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/paddle/jit/dy2static/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
unpack_by_structure as Unpack,
)
from .program_translator import convert_to_static # noqa: F401
from .static_analysis import StaticAnalysisVisitor # noqa: F401
from .transformers import DygraphToStaticAst # noqa: F401
from .utils import UndefinedVar, ast_to_source_code, saw # noqa: F401
from .variable_trans_func import ( # noqa: F401
Expand Down
246 changes: 0 additions & 246 deletions python/paddle/jit/dy2static/static_analysis.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ class DecoratorTransformer(BaseTransformer):
def __init__(self, root):
self.root = root

self.ancestor_nodes = []

def transform(self):
"""
Main function to transform AST.
Expand Down
18 changes: 3 additions & 15 deletions python/paddle/jit/dy2static/transformers/loop_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from paddle.base import unique_name
from paddle.utils import gast

from ..static_analysis import StaticAnalysisVisitor
from ..utils import (
FOR_BODY_PREFIX,
FOR_CONDITION_PREFIX,
Expand All @@ -32,6 +31,7 @@
create_nonlocal_stmt_nodes,
create_set_args_node,
get_attribute_full_name,
get_parent_mapping,
)
from .base import (
BaseTransformer,
Expand Down Expand Up @@ -137,10 +137,7 @@ def __init__(self, root_node):
# Some names are types, we shouldn't record them as loop var names.
self.type_vars = set()

self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
self.node_to_wrapper_map = (
self.static_analysis_visitor.get_node_to_wrapper_map()
)
self.to_parent_mapping = get_parent_mapping(root_node)

self.visit(root_node)

Expand Down Expand Up @@ -184,10 +181,6 @@ def get_loop_var_names(self, node):
write_vars = self.write_in_loop[node]
write_names = self._var_nodes_to_names(write_vars)

name_to_type = {}
for var in in_loop_vars:
wrapper = self.node_to_wrapper_map[var]
name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type
for name in in_loop_name_strs:
if name in before_loop_name_strs:
# If a variable is used in loop and created before loop
Expand Down Expand Up @@ -363,12 +356,7 @@ def _is_ancestor_node(self, ancestor_node, node):
return False

def _get_parent_node(self, node):
wrapper_node = self.node_to_wrapper_map.get(node)
if wrapper_node:
if wrapper_node.parent:
parent_node = wrapper_node.parent.node
return parent_node
return None
return self.to_parent_mapping.get(node)

def _remove_unnecessary_vars(self, loop_vars, loop_node):
"""
Expand Down
11 changes: 10 additions & 1 deletion python/paddle/jit/dy2static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import atexit
import builtins
import copy
Expand Down Expand Up @@ -47,7 +49,6 @@
index_in_list,
is_api_in_module,
is_dygraph_api,
is_numpy_api,
is_paddle_api,
)

Expand Down Expand Up @@ -119,6 +120,14 @@ def visit(self, node):
return ret


def get_parent_mapping(root):
to_parent: dict[gast.AST, gast.AST] = {}
for node in gast.walk(root):
for child in gast.iter_child_nodes(node):
to_parent[child] = node
return to_parent


dygraph_class_to_static_api = {
"CosineDecay": "cosine_decay",
"ExponentialDecay": "exponential_decay",
Expand Down
Loading