@@ -162,9 +162,24 @@ class {} : public egr::GradNodeBase {{
162162FORWARD_FUNCTION_TEMPLATE = \
163163"""
164164 {} {}({}) {{
165- {}
166- {}
167- {}
165+ // Dygraph Record Event
166+ {}
167+ // AMP Logic
168+ {}
169+
170+ // Get Input AutoGradMeta
171+ {}
172+ // Forward API Call
173+ {}
174+ // Get Output AutoGradMeta
175+ {}
176+ bool trace_backward = egr::Controller::Instance().HasGrad();
177+ bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
178+ // Check Inplace & Bump Inplace Version
179+ {}
180+ {}
181+ // Node Creation
182+ {}
168183
169184 // Returns
170185 return {};
@@ -174,18 +189,8 @@ class {} : public egr::GradNodeBase {{
174189
175190FORWARD_BODY_TEMPLATE = \
176191"""
177- // Get AutoGradMeta
178- {}
179- bool trace_backward = egr::Controller::Instance().HasGrad();
180- bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
181- {}
182- // Forward API Call
183- {}
184- {}
185- {{
186- {}
187- {}
188192 if(require_any_grad) {{
193+ {}
189194 egr::EagerUtils::PassStopGradient({});
190195
191196 // Node Construction
@@ -203,7 +208,6 @@ class {} : public egr::GradNodeBase {{
203208{}
204209{}
205210 }}
206- }}
207211"""
208212
209213NAMESPACE_WRAPPER_TEMPLATE = \
@@ -294,7 +298,6 @@ class {} : public egr::GradNodeBase {{
294298
295299CHECK_INPLACE_TEMPLATE = \
296300"""
297- // Check Inplace
298301 egr::EagerUtils::CheckInplace({}, {}, require_any_grad);\n
299302"""
300303
@@ -625,7 +628,7 @@ def SlotNameMatching(self):
625628 f"Generated Backward Grad Output Map: { self .backward_grad_outputs_map } "
626629 )
627630
628- def GenerateNodeCreationCodes (self , forward_call_str , is_inplaced ):
631+ def GenerateNodeCreationCodes (self ):
629632 forward_api_name = self .forward_api_name
630633 forward_inputs_position_map = self .forward_inputs_position_map
631634 forward_outputs_position_map = self .forward_outputs_position_map
@@ -635,67 +638,14 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
635638 backward_grad_outputs_map = self .backward_grad_outputs_map
636639 backward_attrs_list = self .backward_attrs_list
637640 optional_inputs = self .optional_inputs
638- inplace_map = self .inplace_map if is_inplaced else {}
639641
640- # Get Input AutoGradMeta
641- inputs_autograd_meta_list = []
642- compute_require_grad_args_list = ["trace_backward" ]
643- for name , (ttype , pos ) in forward_inputs_position_map .items ():
644- input_autograd_meta_name = GetAutoGradMetaName (name )
645- if IsPlainTensorType (ttype ):
646- input_autograd_meta = f" egr::AutogradMeta* { input_autograd_meta_name } = egr::EagerUtils::nullable_autograd_meta({ name } );"
647- else :
648- assert IsVectorTensorType (ttype )
649- input_autograd_meta_vec_name = GetAutoGradMetaVectorName (name )
650- input_autograd_meta = f" std::vector<egr::AutogradMeta*> { input_autograd_meta_vec_name } = egr::EagerUtils::nullable_autograd_meta({ name } );\n "
651- input_autograd_meta += f" std::vector<egr::AutogradMeta*>* { input_autograd_meta_name } = &{ input_autograd_meta_vec_name } ;"
652-
653- inputs_autograd_meta_list .append (input_autograd_meta )
654- compute_require_grad_args_list .append (input_autograd_meta_name )
655- inputs_autograd_meta_str = "\n " .join (inputs_autograd_meta_list )
656- compute_require_grad_args_str = "," .join (compute_require_grad_args_list )
657-
658- # Get Output AutoGradMeta
659- outputs_autograd_meta_list = []
642+ # Pass Stop Gradient Args
660643 pass_stop_gradient_args_list = ["false" ]
661- num_fwd_outputs = len (forward_outputs_position_map .keys ())
662- for name , (rtype , pos ) in forward_outputs_position_map .items ():
644+ for name , (_ , _ ) in forward_outputs_position_map .items ():
663645 output_autograd_meta_name = GetAutoGradMetaName (name )
664- output_autograd_meta_vec_name = GetAutoGradMetaVectorName (name )
665- if num_fwd_outputs == 1 :
666- if IsPlainTensorType (rtype ):
667- output_autograd_meta = f" egr::AutogradMeta* { output_autograd_meta_name } = egr::EagerUtils::autograd_meta(&api_result);"
668- else :
669- assert IsVectorTensorType (rtype )
670- output_autograd_meta = f" std::vector<egr::AutogradMeta*> { output_autograd_meta_vec_name } = egr::EagerUtils::autograd_meta(&api_result);\n "
671- output_autograd_meta += f" std::vector<egr::AutogradMeta*>* { output_autograd_meta_name } = &{ output_autograd_meta_vec_name } ;"
672- else :
673- # Tuple api_result
674- if IsPlainTensorType (rtype ):
675- output_autograd_meta = f" egr::AutogradMeta* { output_autograd_meta_name } = egr::EagerUtils::autograd_meta(&std::get<{ pos } >(api_result));"
676- else :
677- assert IsVectorTensorType (rtype )
678- output_autograd_meta = f" std::vector<egr::AutogradMeta*> { output_autograd_meta_vec_name } = egr::EagerUtils::autograd_meta(&std::get<{ pos } >(api_result));\n "
679- output_autograd_meta += f" std::vector<egr::AutogradMeta*>* { output_autograd_meta_name } = &{ output_autograd_meta_vec_name } ;"
680-
681- outputs_autograd_meta_list .append (output_autograd_meta )
682646 pass_stop_gradient_args_list .append (output_autograd_meta_name )
683-
684- # ComputeRequireGrad & PassStopGradient
685- outputs_autograd_meta_str = "\n " .join (outputs_autograd_meta_list )
686647 pass_stop_gradient_args_str = "," .join (pass_stop_gradient_args_list )
687648
688- # Check Inplace
689- check_inplace_str = ""
690- bump_inplace_version_str = ""
691- if is_inplaced :
692- for inplace_name in inplace_map .keys ():
693- inplace_autograd_meta_name = GetAutoGradMetaName (inplace_name )
694- check_inplace_str += CHECK_INPLACE_TEMPLATE .format (
695- inplace_name , inplace_autograd_meta_name )
696- bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE .format (
697- inplace_name , inplace_name )
698-
699649 # Node Construction
700650 num_backward_inputs = len (forward_outputs_position_map .keys ())
701651 num_backward_outputs = len (forward_inputs_position_map .keys ())
@@ -719,6 +669,7 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
719669
720670 # SetTensorWrappers
721671 set_tensor_wrappers_list = []
672+ num_fwd_outputs = len (forward_outputs_position_map .keys ())
722673 for name , (atype , is_fwd_input ,
723674 pos ) in backward_forward_inputs_map .items ():
724675 is_optional = (name in optional_inputs )
@@ -794,13 +745,10 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
794745 node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\" { node_event_name } \" , paddle::platform::TracerEventType::Operator, 1);\n "
795746
796747 self .node_creation_str = FORWARD_BODY_TEMPLATE .format (
797- inputs_autograd_meta_str , compute_require_grad_args_str ,
798- check_inplace_str , forward_call_str , bump_inplace_version_str ,
799- node_creation_event_str , outputs_autograd_meta_str ,
800- pass_stop_gradient_args_str , node_construction_str ,
801- set_attributes_str , set_tensor_wrappers_str , set_grad_out_meta_str ,
802- set_edges_str , set_out_rank_str , set_history_str ,
803- set_grad_in_meta_str , set_retain_grad_str )
748+ node_creation_event_str , pass_stop_gradient_args_str ,
749+ node_construction_str , set_attributes_str , set_tensor_wrappers_str ,
750+ set_grad_out_meta_str , set_edges_str , set_out_rank_str ,
751+ set_history_str , set_grad_in_meta_str , set_retain_grad_str )
804752
805753 def run (self ):
806754 # Basic Validation Check
@@ -973,7 +921,64 @@ def GenerateForwardDefinition(self, is_inplaced):
973921 returns_str = ", " .join (returns_list )
974922 returns_str = f"std::make_tuple({ returns_str } )"
975923
976- self .GenerateNodeCreationCodes (forward_call_str , is_inplaced )
924+ # Node Creation Pre-Processing
925+ # 1. Get Input AutoGradMeta
926+ inputs_autograd_meta_list = []
927+ compute_require_grad_args_list = ["trace_backward" ]
928+ for name , (ttype , pos ) in forward_inputs_position_map .items ():
929+ input_autograd_meta_name = GetAutoGradMetaName (name )
930+ if IsPlainTensorType (ttype ):
931+ input_autograd_meta = f" egr::AutogradMeta* { input_autograd_meta_name } = egr::EagerUtils::nullable_autograd_meta({ name } );"
932+ else :
933+ assert IsVectorTensorType (ttype )
934+ input_autograd_meta_vec_name = GetAutoGradMetaVectorName (name )
935+ input_autograd_meta = f" std::vector<egr::AutogradMeta*> { input_autograd_meta_vec_name } = egr::EagerUtils::nullable_autograd_meta({ name } );\n "
936+ input_autograd_meta += f" std::vector<egr::AutogradMeta*>* { input_autograd_meta_name } = &{ input_autograd_meta_vec_name } ;"
937+
938+ inputs_autograd_meta_list .append (input_autograd_meta )
939+ compute_require_grad_args_list .append (input_autograd_meta_name )
940+ inputs_autograd_meta_str = "\n " .join (inputs_autograd_meta_list )
941+ compute_require_grad_args_str = "," .join (compute_require_grad_args_list )
942+
943+ # 2. Get Output AutoGradMeta
944+ outputs_autograd_meta_list = []
945+ num_fwd_outputs = len (forward_outputs_position_map .keys ())
946+ for name , (rtype , pos ) in forward_outputs_position_map .items ():
947+ output_autograd_meta_name = GetAutoGradMetaName (name )
948+ output_autograd_meta_vec_name = GetAutoGradMetaVectorName (name )
949+ if num_fwd_outputs == 1 :
950+ if IsPlainTensorType (rtype ):
951+ output_autograd_meta = f" egr::AutogradMeta* { output_autograd_meta_name } = egr::EagerUtils::autograd_meta(&api_result);"
952+ else :
953+ assert IsVectorTensorType (rtype )
954+ output_autograd_meta = f" std::vector<egr::AutogradMeta*> { output_autograd_meta_vec_name } = egr::EagerUtils::autograd_meta(&api_result);\n "
955+ output_autograd_meta += f" std::vector<egr::AutogradMeta*>* { output_autograd_meta_name } = &{ output_autograd_meta_vec_name } ;"
956+ else :
957+ # Tuple api_result
958+ if IsPlainTensorType (rtype ):
959+ output_autograd_meta = f" egr::AutogradMeta* { output_autograd_meta_name } = egr::EagerUtils::autograd_meta(&std::get<{ pos } >(api_result));"
960+ else :
961+ assert IsVectorTensorType (rtype )
962+ output_autograd_meta = f" std::vector<egr::AutogradMeta*> { output_autograd_meta_vec_name } = egr::EagerUtils::autograd_meta(&std::get<{ pos } >(api_result));\n "
963+ output_autograd_meta += f" std::vector<egr::AutogradMeta*>* { output_autograd_meta_name } = &{ output_autograd_meta_vec_name } ;"
964+
965+ outputs_autograd_meta_list .append (output_autograd_meta )
966+
967+ # 3. ComputeRequireGrad & PassStopGradient
968+ outputs_autograd_meta_str = "\n " .join (outputs_autograd_meta_list )
969+
970+ # 4. Check Inplace
971+ check_inplace_str = ""
972+ bump_inplace_version_str = ""
973+ if is_inplaced :
974+ for inplace_name in inplace_map .keys ():
975+ inplace_autograd_meta_name = GetAutoGradMetaName (inplace_name )
976+ check_inplace_str += CHECK_INPLACE_TEMPLATE .format (
977+ inplace_name , inplace_autograd_meta_name )
978+ bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE .format (
979+ inplace_name , inplace_name )
980+
981+ self .GenerateNodeCreationCodes ()
977982
978983 node_creation_str = self .node_creation_str
979984 dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\" { forward_api_name } dygraph\" , paddle::platform::TracerEventType::Operator, 1);"
@@ -1001,7 +1006,10 @@ def GenerateForwardDefinition(self, is_inplaced):
10011006
10021007 self .forward_definition_str += FORWARD_FUNCTION_TEMPLATE .format (
10031008 returns_type_str , forward_function_name , inputs_args_definition_str ,
1004- dygraph_event_str , amp_logic_str , node_creation_str , returns_str )
1009+ dygraph_event_str , amp_logic_str , inputs_autograd_meta_str ,
1010+ forward_call_str , outputs_autograd_meta_str ,
1011+ compute_require_grad_args_str , check_inplace_str ,
1012+ bump_inplace_version_str , node_creation_str , returns_str )
10051013 self .forward_declaration_str += f"{ returns_type_str } { forward_function_name } ({ inputs_args_declaration_str } );\n "
10061014
10071015 logging .info (
0 commit comments