@@ -47,6 +47,9 @@ class DygraphToStaticAst(gast.NodeTransformer):
4747 Main class to transform Dygraph to Static Graph
4848 """
4949
50+ def __init__ (self ):
51+ self .translator_logger = logging_utils .TranslatorLogger ()
52+
5053 def get_static_ast (self , root ):
5154 # save root for some analysis may need global AST
5255 self .root = root
@@ -57,71 +60,37 @@ def get_static_ast(self, root):
5760 self .transfer_from_node_type (self .static_analysis_root )
5861 return self .static_analysis_root
5962
63+ def _apply (self , transformer , node_wrapper , log_level ):
64+ transformer (node_wrapper ).transform ()
65+ self .translator_logger .log_transformed_code (log_level , self .root ,
66+ transformer .__name__ )
67+
6068 def transfer_from_node_type (self , node_wrapper ):
61- translator_logger = logging_utils .TranslatorLogger ()
62- translator_logger .log (
69+ self .translator_logger .log (
6370 1 , "Source code: \n {}" .format (ast_to_source_code (self .root )))
6471 # Generic transformation
6572 self .visit (node_wrapper .node )
6673
67- # Transform basic api of dygraph to static graph and get feed_name_to_arg_name
68- BasicApiTransformer (node_wrapper ).transform ()
69- translator_logger .log_transformed_code (1 , self .root ,
70- "BasicApiTransformer" )
71-
72- # Transform Tensor.shape into fluid.layers.shape(Tensor)
73- TensorShapeTransformer (node_wrapper ).transform ()
74- translator_logger .log_transformed_code (2 , self .root ,
75- "TensorShapeTransformer" )
76-
77- # Transform list used in control flow
78- ListTransformer (node_wrapper ).transform ()
79- translator_logger .log_transformed_code (3 , self .root , "ListTransformer" )
80-
81- # Transform break/continue in loops
82- BreakContinueTransformer (node_wrapper ).transform ()
83- translator_logger .log_transformed_code (4 , self .root ,
84- "BreakContinueTransformer" )
85-
86- # Transform return in functions
87- ReturnTransformer (node_wrapper ).transform ()
88- translator_logger .log_transformed_code (5 , self .root ,
89- "ReturnTransformer" )
90-
91- # Transform logical and/or/not
92- LogicalTransformer (node_wrapper ).transform ()
93- translator_logger .log_transformed_code (6 , self .root ,
94- "LogicalTransformer" )
95-
96- # Transform for loop and while loop
97- LoopTransformer (node_wrapper ).transform ()
98- translator_logger .log_transformed_code (7 , self .root , "LoopTransformer" )
99-
100- # Transform all if/else statement of Dygraph into Static Graph.
101- IfElseTransformer (node_wrapper ).transform ()
102- translator_logger .log_transformed_code (8 , self .root ,
103- "IfElseTransformer" )
104-
105- # Transform python assert statement
106- AssertTransformer (node_wrapper ).transform ()
107- translator_logger .log_transformed_code (9 , self .root ,
108- "AssertTransformer" )
109-
110- # Transform all python print statement
111- PrintTransformer (node_wrapper ).transform ()
112- translator_logger .log_transformed_code (10 , self .root ,
113- "PrintTransformer" )
114-
115- # Transform call recursively
116- CallTransformer (node_wrapper ).transform ()
117- translator_logger .log_transformed_code (11 , self .root , "CallTransformer" )
118-
119- # Transform python type casting statement
120- CastTransformer (node_wrapper ).transform ()
121- translator_logger .log_transformed_code (12 , self .root , "CastTransformer" )
122-
123- translator_logger .log_transformed_code (logging_utils .LOG_AllTransformer ,
124- self .root , "All Transformers" )
74+ transformers = [
75+ BasicApiTransformer , # Basic Api
76+ TensorShapeTransformer , # Tensor.shape -> layers.shape(Tensor)
77+ ListTransformer , # List used in control flow
78+ BreakContinueTransformer , # break/continue in loops
79+ ReturnTransformer , # return in functions
80+ LogicalTransformer , # logical and/or/not
81+ LoopTransformer , # for/while -> while_op
82+ IfElseTransformer , # if/else -> cond_op
83+ AssertTransformer , # assert statement
84+ PrintTransformer , # print statement
85+ CallTransformer , # transform call recursively
86+ CastTransformer , # type casting statement
87+ ]
88+
89+ for index , transformer in enumerate (transformers ):
90+ self ._apply (transformer , node_wrapper , log_level = index + 1 )
91+
92+ self .translator_logger .log_transformed_code (
93+ logging_utils .LOG_AllTransformer , self .root , "All Transformers" )
12594
12695 def visit_FunctionDef (self , node ):
12796 if self .decorate_func_name is None :
0 commit comments