@@ -292,6 +292,7 @@ class {} : public egr::GradNodeBase {{
292292#include "paddle/fluid/eager/utils.h"
293293#include "paddle/fluid/eager/api/utils/global_utils.h"
294294#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
295+ #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
295296#include "paddle/fluid/eager/to_static/run_program_op_node.h"
296297#include "paddle/fluid/eager/nan_inf_utils.h"
297298
@@ -436,6 +437,20 @@ class {} : public egr::GradNodeBase {{
436437}
437438
438439
440+ def ExtractForwardApiNameFormInvoke (invoke_config ):
441+ api_name = invoke_config .split ('(' )[0 ]
442+ if api_name [- 1 ] == '_' :
443+ api_name = api_name [:- 1 ]
444+ return re .search (
445+ r"(?P<api_name>[a-zA-Z0-9_]+)(?P<intermediate>_intermediate)?" ,
446+ api_name ).group ('api_name' )
447+
448+
449+ def IsInvokeForwardApi (api_contents , forward_api_name_list ):
450+ return 'invoke' in api_contents and ExtractForwardApiNameFormInvoke (
451+ api_contents ['invoke' ]) in forward_api_name_list
452+
453+
439454#######################
440455## Generator Helpers ##
441456#######################
@@ -478,7 +493,8 @@ def GenerateCoreOpInfoDefinition():
478493#####################
479494class DygraphFunctionGeneratorBase (FunctionGeneratorBase ):
480495
481- def __init__ (self , forward_api_contents , grad_api_contents , namespace ):
496+ def __init__ (self , forward_api_contents , grad_api_contents ,
497+ forward_apis_dict , namespace ):
482498 self .forward_api_contents = forward_api_contents
483499 # Members from Parent:
484500 #self.namespace
@@ -495,6 +511,7 @@ def __init__(self, forward_api_contents, grad_api_contents, namespace):
495511 #self.forward_inplace_map
496512 FunctionGeneratorBase .__init__ (self , forward_api_contents , namespace )
497513
514+ self .forward_apis_dict = forward_apis_dict
498515 self .grad_api_contents = grad_api_contents
499516
500517 # Raw Contents
@@ -935,9 +952,11 @@ def run(self):
935952
936953class DygraphForwardFunctionGenerator (DygraphFunctionGeneratorBase ):
937954
938- def __init__ (self , forward_api_contents , grad_api_contents , namespace ):
955+ def __init__ (self , forward_api_contents , grad_api_contents ,
956+ forward_apis_dict , namespace ):
939957 DygraphFunctionGeneratorBase .__init__ (self , forward_api_contents ,
940- grad_api_contents , namespace )
958+ grad_api_contents ,
959+ forward_apis_dict , namespace )
941960
942961 # Generated Results
943962 self .forward_definition_str = ""
@@ -1299,10 +1318,12 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
12991318 def __init__ (self ,
13001319 forward_api_contents ,
13011320 grad_api_contents ,
1321+ forward_apis_dict ,
13021322 namespace ,
13031323 next_grad_api_contents = None ):
13041324 DygraphFunctionGeneratorBase .__init__ (self , forward_api_contents ,
1305- grad_api_contents , namespace )
1325+ grad_api_contents ,
1326+ forward_apis_dict , namespace )
13061327
13071328 # Record name mapping from forward_var_name to grad_var_names
13081329 self .to_next_grad_name_mapping = {} # {name : name}
@@ -1346,6 +1367,7 @@ def RecordGrad2NextGradNameMapping(self, next_node_generator):
13461367 def GenerateHigherOrderNodeCreationCode (self ):
13471368 namespace = self .namespace
13481369 grad_api_contents = self .grad_api_contents
1370+ forward_apis_dict = self .forward_apis_dict
13491371 next_grad_api_contents = self .next_grad_api_contents
13501372
13511373 next_grad_node_creation_str = ""
@@ -1358,7 +1380,8 @@ def GenerateHigherOrderNodeCreationCode(self):
13581380 backward_api_contents = next_grad_api_contents
13591381
13601382 next_node_generator = DygraphFunctionGeneratorBase (
1361- forward_api_contents , backward_api_contents , namespace )
1383+ forward_api_contents , backward_api_contents , forward_apis_dict ,
1384+ namespace )
13621385 next_node_generator .run ()
13631386 next_node_generator .GenerateNodeCreationCodes (True )
13641387
@@ -1443,6 +1466,8 @@ def GenerateNodeDefinition(self, next_grad_node_creation_str,
14431466 backward_inplace_map = self .backward_inplace_map
14441467 indent = GetIndent (1 )
14451468
1469+ is_invoke_forward_api = IsInvokeForwardApi (self .grad_api_contents ,
1470+ self .forward_apis_dict )
14461471 # Construct grad_api function args
14471472 # Order: TensorWrappers, GradTensors, Attributes
14481473 grad_api_args_len = len (backward_forward_inputs_map .keys ()) + len (
@@ -1575,11 +1600,18 @@ def GenerateNodeDefinition(self, next_grad_node_creation_str,
15751600 optional_inplace_str = ""
15761601 # Grad Outputs
15771602 out_index = - 1
1603+ out_assign_str = ""
15781604 for name , (ttype , fwd_position ,
15791605 grad_api_position ) in backward_grad_outputs_map .items ():
15801606 transformed_tensor_name = self .TransformToNextGradName (name )
15811607 out_index = out_index + 1
1582- grad_api_args .append (f"api_output_{ out_index } " )
1608+ if is_invoke_forward_api :
1609+ if len (backward_grad_outputs_map ) == 1 :
1610+ out_assign_str += f"{ indent } *api_output_{ out_index } = api_output;\n "
1611+ else :
1612+ out_assign_str += f"{ indent } *api_output_{ out_index } = std::get<{ out_index } >(api_output);\n "
1613+ else :
1614+ grad_api_args .append (f"api_output_{ out_index } " )
15831615 if inplace_grad_input_str in optional_inplace_var_name :
15841616 optional_inplace_str = "VLOG(6) << \" No Inplace should happend for wrappered input: {inplace_grad_input_str}\" ;"
15851617 else :
@@ -1621,7 +1653,24 @@ def GenerateNodeDefinition(self, next_grad_node_creation_str,
16211653
16221654 grad_api_args_str = ", " .join (grad_api_args )
16231655
1624- grad_function_call_str = f"""
1656+ if is_invoke_forward_api :
1657+ autograd_api_out = "auto"
1658+ if len (self .backward_inplace_map ) > 0 and len (
1659+ backward_grad_outputs_map ) == 1 :
1660+ autograd_api_out = "auto&"
1661+ forward_api_name = self .grad_api_contents ['invoke' ].split (
1662+ '(' )[0 ].strip ()
1663+ autograd_api = self .grad_api_contents ['invoke' ].replace (
1664+ forward_api_name , forward_api_name + '_dygraph_function' , 1 )
1665+ grad_function_call_str = f"""
1666+ if (trace_backward) {{
1667+ { indent } { autograd_api_out } api_output = { autograd_api } ;
1668+ { out_assign_str } }} else {{
1669+ { indent } { autograd_api_out } api_output = paddle::experimental::{ self .namespace } { self .grad_api_contents ['invoke' ]} ;
1670+ { out_assign_str } { indent } }}
1671+ """
1672+ else :
1673+ grad_function_call_str = f"""
16251674{ indent } { grad_api_namespace } { backward_api_name } ({ grad_api_args_str } );"""
16261675
16271676 # Check Nan and Inf
@@ -1631,7 +1680,7 @@ def GenerateNodeDefinition(self, next_grad_node_creation_str,
16311680 # Prepare for Node Creation if Necessary
16321681 outputs_autograd_meta_str = ""
16331682 compute_require_next_grad_str = ""
1634- if len (next_grad_node_creation_str ) > 0 :
1683+ if len (next_grad_node_creation_str ) > 0 or is_invoke_forward_api :
16351684 compute_require_next_grad_str = f"{ indent } bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n "
16361685
16371686 # 3. Get Output AutoGradMeta
@@ -1754,6 +1803,9 @@ def GetBackwardAPIContents(self, forward_api_contents):
17541803 def GenerateCode (self ):
17551804 forward_api_list = self .forward_api_list
17561805 grad_api_dict = self .grad_api_dict
1806+ forward_apis_dict = {}
1807+ for api_item in forward_api_list :
1808+ forward_apis_dict [api_item ['api' ]] = api_item
17571809 namespace = self .namespace
17581810
17591811 for forward_api_contents in forward_api_list :
@@ -1769,7 +1821,8 @@ def GenerateCode(self):
17691821
17701822 # Generate Dygraph Forward Function
17711823 function_generator = DygraphForwardFunctionGenerator (
1772- forward_api_contents , backward_api_contents , namespace )
1824+ forward_api_contents , backward_api_contents , forward_apis_dict ,
1825+ namespace )
17731826 function_generator .run ()
17741827
17751828 self .forward_definition_str += function_generator .forward_definition_str + "\n "
@@ -1784,6 +1837,7 @@ def GenerateCode(self):
17841837
17851838 node_generator = DygraphNodeGenerator (forward_api_contents ,
17861839 backward_api_contents ,
1840+ forward_apis_dict ,
17871841 namespace ,
17881842 next_grad_api_contents )
17891843 node_generator .run ()
0 commit comments