Skip to content

Commit 135b62a

Browse files
authored
[Dy2stat] Refine code of DygraphToStaticAst (#28103)
* refine code of DygraphToStaticAst * add __init__ function
1 parent 6dd64b0 commit 135b62a

File tree

1 file changed

+29
-60
lines changed

1 file changed

+29
-60
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py

Lines changed: 29 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class DygraphToStaticAst(gast.NodeTransformer):
4747
Main class to transform Dygraph to Static Graph
4848
"""
4949

50+
def __init__(self):
51+
self.translator_logger = logging_utils.TranslatorLogger()
52+
5053
def get_static_ast(self, root):
5154
# save root for some analysis may need global AST
5255
self.root = root
@@ -57,71 +60,37 @@ def get_static_ast(self, root):
5760
self.transfer_from_node_type(self.static_analysis_root)
5861
return self.static_analysis_root
5962

63+
def _apply(self, transformer, node_wrapper, log_level):
64+
transformer(node_wrapper).transform()
65+
self.translator_logger.log_transformed_code(log_level, self.root,
66+
transformer.__name__)
67+
6068
def transfer_from_node_type(self, node_wrapper):
61-
translator_logger = logging_utils.TranslatorLogger()
62-
translator_logger.log(
69+
self.translator_logger.log(
6370
1, "Source code: \n{}".format(ast_to_source_code(self.root)))
6471
# Generic transformation
6572
self.visit(node_wrapper.node)
6673

67-
# Transform basic api of dygraph to static graph and get feed_name_to_arg_name
68-
BasicApiTransformer(node_wrapper).transform()
69-
translator_logger.log_transformed_code(1, self.root,
70-
"BasicApiTransformer")
71-
72-
# Transform Tensor.shape into fluid.layers.shape(Tensor)
73-
TensorShapeTransformer(node_wrapper).transform()
74-
translator_logger.log_transformed_code(2, self.root,
75-
"TensorShapeTransformer")
76-
77-
# Transform list used in control flow
78-
ListTransformer(node_wrapper).transform()
79-
translator_logger.log_transformed_code(3, self.root, "ListTransformer")
80-
81-
# Transform break/continue in loops
82-
BreakContinueTransformer(node_wrapper).transform()
83-
translator_logger.log_transformed_code(4, self.root,
84-
"BreakContinueTransformer")
85-
86-
# Transform return in functions
87-
ReturnTransformer(node_wrapper).transform()
88-
translator_logger.log_transformed_code(5, self.root,
89-
"ReturnTransformer")
90-
91-
# Transform logical and/or/not
92-
LogicalTransformer(node_wrapper).transform()
93-
translator_logger.log_transformed_code(6, self.root,
94-
"LogicalTransformer")
95-
96-
# Transform for loop and while loop
97-
LoopTransformer(node_wrapper).transform()
98-
translator_logger.log_transformed_code(7, self.root, "LoopTransformer")
99-
100-
# Transform all if/else statement of Dygraph into Static Graph.
101-
IfElseTransformer(node_wrapper).transform()
102-
translator_logger.log_transformed_code(8, self.root,
103-
"IfElseTransformer")
104-
105-
# Transform python assert statement
106-
AssertTransformer(node_wrapper).transform()
107-
translator_logger.log_transformed_code(9, self.root,
108-
"AssertTransformer")
109-
110-
# Transform all python print statement
111-
PrintTransformer(node_wrapper).transform()
112-
translator_logger.log_transformed_code(10, self.root,
113-
"PrintTransformer")
114-
115-
# Transform call recursively
116-
CallTransformer(node_wrapper).transform()
117-
translator_logger.log_transformed_code(11, self.root, "CallTransformer")
118-
119-
# Transform python type casting statement
120-
CastTransformer(node_wrapper).transform()
121-
translator_logger.log_transformed_code(12, self.root, "CastTransformer")
122-
123-
translator_logger.log_transformed_code(logging_utils.LOG_AllTransformer,
124-
self.root, "All Transformers")
74+
transformers = [
75+
BasicApiTransformer, # Basic Api
76+
TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor)
77+
ListTransformer, # List used in control flow
78+
BreakContinueTransformer, # break/continue in loops
79+
ReturnTransformer, # return in functions
80+
LogicalTransformer, # logical and/or/not
81+
LoopTransformer, # for/while -> while_op
82+
IfElseTransformer, # if/else -> cond_op
83+
AssertTransformer, # assert statement
84+
PrintTransformer, # print statement
85+
CallTransformer, # transform call recursively
86+
CastTransformer, # type casting statement
87+
]
88+
89+
for index, transformer in enumerate(transformers):
90+
self._apply(transformer, node_wrapper, log_level=index + 1)
91+
92+
self.translator_logger.log_transformed_code(
93+
logging_utils.LOG_AllTransformer, self.root, "All Transformers")
12594

12695
def visit_FunctionDef(self, node):
12796
if self.decorate_func_name is None:

0 commit comments

Comments
 (0)