3131from codegen_utils import ParseYamlForward , ParseYamlBackward
3232from codegen_utils import FunctionGeneratorBase , YamlGeneratorBase
3333from codegen_utils import ops_to_fill_zero_for_empty_grads
34- from codegen_utils import TransformGradVarNameForDoubleGradGeneration
3534from codegen_utils import AssertMessage , GetIndent
3635
3736
@@ -483,10 +482,8 @@ def ForwardsValidationCheck(self):
483482 orig_forward_returns_list = self .orig_forward_returns_list
484483
485484 for i in range (len (forward_inputs_list )):
486- forward_input_name = forward_inputs_list [i ][0 ]
487485 forward_input_type = forward_inputs_list [i ][1 ]
488486 forward_input_pos = forward_inputs_list [i ][2 ]
489- orig_input_name = orig_forward_inputs_list [i ][0 ]
490487 orig_input_type = orig_forward_inputs_list [i ][1 ]
491488 orig_input_pos = orig_forward_inputs_list [i ][2 ]
492489
@@ -496,11 +493,9 @@ def ForwardsValidationCheck(self):
496493 forward_input_pos , orig_input_pos )
497494
498495 for i in range (len (forward_attrs_list )):
499- orig_attr_name = orig_forward_attrs_list [i ][0 ]
500496 orig_attr_type = orig_forward_attrs_list [i ][1 ]
501497 orig_attr_default = orig_forward_attrs_list [i ][2 ]
502498 orig_attr_pos = orig_forward_attrs_list [i ][3 ]
503- forward_attr_name = forward_attrs_list [i ][0 ]
504499 forward_attr_type = forward_attrs_list [i ][1 ]
505500 forward_attr_default = forward_attrs_list [i ][2 ]
506501 forward_attr_pos = forward_attrs_list [i ][3 ]
@@ -1133,11 +1128,20 @@ def __init__(self,
11331128 DygraphFunctionGeneratorBase .__init__ (self , forward_api_contents ,
11341129 grad_api_contents , namespace )
11351130
1131+ # Record name mapping from forward_api_name to grad_api_names
1132+ self .to_next_grad_name_mapping = {} # {name : name}
1133+
11361134 # Generated Results
11371135 self .node_declaration_str = ""
11381136 self .node_definition_str = ""
11391137 self .next_grad_api_contents = next_grad_api_contents
11401138
def TransformToNextGradName(self, string):
    """Map a grad-API variable name to its next (higher-order) forward name.

    Looks *string* up in ``self.to_next_grad_name_mapping`` (populated by
    ``RecordGrad2NextGradNameMapping``).  Names without a recorded mapping
    are returned unchanged.

    Args:
        string: a grad-API variable name.

    Returns:
        The mapped name, or *string* itself when no mapping exists.
    """
    # dict.get performs one O(1) lookup instead of the original
    # `in mapping.keys()` membership test followed by a second subscript.
    return self.to_next_grad_name_mapping.get(string, string)
1144+
11411145 def ResetOptionalInputs (self ):
11421146 namespace = self .namespace
11431147 grad_api_contents = self .grad_api_contents
@@ -1147,6 +1151,22 @@ def ResetOptionalInputs(self):
11471151
11481152 self .optional_inputs = base_generator .optional_inputs
11491153
def RecordGrad2NextGradNameMapping(self, next_node_generator):
    """Record the name mapping from grad-API names to the next forward API's names.

    For each positional entry of the next node generator's input and return
    lists, the grad name (element 0 of the ``orig_forward_*`` tuple) is mapped
    to the corresponding higher-order forward name (element 0 of the
    ``forward_*`` tuple).  Entries accumulate in
    ``self.to_next_grad_name_mapping``, which ``TransformToNextGradName``
    consults later.

    Args:
        next_node_generator: generator exposing ``orig_forward_inputs_list``,
            ``forward_inputs_list``, ``orig_forward_returns_list`` and
            ``forward_returns_list`` — parallel lists of (name, ...) tuples.
            They are assumed to be index-aligned pairs — TODO confirm against
            the validation performed upstream.
    """
    # Iterate the orig/forward lists in lockstep with zip instead of
    # indexing each list by range(len(...)).
    for (grad_name, *_), (next_name, *_) in zip(
            next_node_generator.orig_forward_inputs_list,
            next_node_generator.forward_inputs_list):
        self.to_next_grad_name_mapping[grad_name] = next_name

    for (grad_ret_name, *_), (next_ret_name, *_) in zip(
            next_node_generator.orig_forward_returns_list,
            next_node_generator.forward_returns_list):
        self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name
1169+
11501170 def GenerateHigherOrderNodeCreationCode (self ):
11511171 namespace = self .namespace
11521172 grad_api_contents = self .grad_api_contents
@@ -1164,6 +1184,8 @@ def GenerateHigherOrderNodeCreationCode(self):
11641184 next_node_generator .GenerateNodeCreationCodes ()
11651185 grad_node_creation_str = next_node_generator .node_creation_str
11661186
1187+ self .RecordGrad2NextGradNameMapping (next_node_generator )
1188+
11671189 return grad_node_creation_str
11681190
11691191 def GenerateNodeDeclaration (self ):
@@ -1253,8 +1275,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
12531275 for name , (_ , is_fwd_input ,
12541276 grad_api_position ), in backward_forward_inputs_map .items ():
12551277 tensor_wrapper_name = GetSavedName (name )
1256- transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration (
1257- name )
1278+ transformed_tensor_name = self .TransformToNextGradName (name )
12581279
12591280 is_optional = (name in self .optional_inputs )
12601281 tensor_wrapper_recover_str = f"{ indent } auto { transformed_tensor_name } = egr::EagerUtils::RecoverTensorWrapper(&this->{ tensor_wrapper_name } , this->shared_from_this());"
@@ -1274,8 +1295,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
12741295 # Grad Ins from grads
12751296 for name , (ttype , fwd_position ,
12761297 grad_api_position ) in backward_grad_inputs_map .items ():
1277- transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration (
1278- name )
1298+ transformed_tensor_name = self .TransformToNextGradName (name )
12791299
12801300 is_optional = (name in self .optional_inputs )
12811301 if IsPlainTensorType (ttype ):
@@ -1316,8 +1336,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
13161336 num_outputs = len (backward_grad_outputs_map .keys ())
13171337 for name , (ttype , fwd_position ,
13181338 grad_api_position ) in backward_grad_outputs_map .items ():
1319- transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration (
1320- name )
1339+ transformed_tensor_name = self .TransformToNextGradName (name )
13211340
13221341 if num_outputs == 1 :
13231342 get_tensor_str = f"{ indent } auto& { transformed_tensor_name } = grad_api_result;"
@@ -1339,8 +1358,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
13391358 compute_require_grad_args_list = ["trace_backward" ]
13401359 for name , (ttype , pos ,
13411360 grad_api_position ) in backward_grad_inputs_map .items ():
1342- transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration (
1343- name )
1361+ transformed_tensor_name = self .TransformToNextGradName (name )
13441362
13451363 input_autograd_meta_name = GetAutoGradMetaName (
13461364 transformed_tensor_name )
@@ -1358,8 +1376,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
13581376
13591377 # 2. Get TensorWrapper AutoGradMeta
13601378 for name , (ttype , _ , pos ), in backward_forward_inputs_map .items ():
1361- transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration (
1362- name )
1379+ transformed_tensor_name = self .TransformToNextGradName (name )
13631380
13641381 input_autograd_meta_name = GetAutoGradMetaName (
13651382 transformed_tensor_name )
@@ -1382,8 +1399,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
13821399 outputs_autograd_meta_list = []
13831400 num_fwd_outputs = len (backward_grad_outputs_map .keys ())
13841401 for name , (rtype , pos , _ ) in backward_grad_outputs_map .items ():
1385- transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration (
1386- name )
1402+ transformed_tensor_name = self .TransformToNextGradName (name )
13871403
13881404 output_autograd_meta_name = GetAutoGradMetaName (
13891405 transformed_tensor_name )
@@ -1417,8 +1433,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
14171433 returns_str = f"{ indent } std::vector<std::vector<paddle::experimental::Tensor>> returns({ slot_num_bwd_outputs } );\n "
14181434 for name , (ttype , fwd_position ,
14191435 grad_api_position ) in backward_grad_outputs_map .items ():
1420- transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration (
1421- name )
1436+ transformed_tensor_name = self .TransformToNextGradName (name )
14221437
14231438 # Infer Grad API Return Type
14241439 if num_bwd_outputs == 1 :
@@ -1441,6 +1456,9 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
14411456
14421457 grad_node_name = GetGradNodeName (forward_api_name )
14431458
1459+ if len (grad_node_creation_str ) == 0 :
1460+ grad_node_creation_str = f"if(create_graph) VLOG(3) << \" Higher order grad node for { grad_node_name } has not been implemented yet.\" ;"
1461+
14441462 self .node_definition_str = GRAD_FUNCTION_TEMPLATE .format (
14451463 grad_node_name , fill_zero_str , get_grad_in_args_str , grad_node_name ,
14461464 grad_function_call_str , get_outputs_str , inputs_autograd_meta_str ,
@@ -1457,11 +1475,11 @@ def run(self):
14571475 #####################
14581476 ## Code Generation ##
14591477 #####################
1460- self .GenerateNodeDeclaration ()
1461-
14621478 # Higher-order GradNode generation
14631479 grad_node_creation_str = self .GenerateHigherOrderNodeCreationCode ()
14641480
1481+ self .GenerateNodeDeclaration ()
1482+
14651483 self .GenerateNodeDefinition (grad_node_creation_str )
14661484
14671485
0 commit comments