Skip to content

Commit 3983c72

Browse files
authored
[DoubleGrad PR #2] Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition (PaddlePaddle#41016)
* [Refactor] refactored eager_gen.py PR #2 * [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes * Fixed minor issue * Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition * Fixed issues * Fixed minor issue
1 parent 93a2f56 commit 3983c72

File tree

1 file changed

+89
-81
lines changed
  • paddle/fluid/eager/auto_code_generator/final_state_generator

1 file changed

+89
-81
lines changed

paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py

Lines changed: 89 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,24 @@ class {} : public egr::GradNodeBase {{
162162
FORWARD_FUNCTION_TEMPLATE = \
163163
"""
164164
{} {}({}) {{
165-
{}
166-
{}
167-
{}
165+
// Dygraph Record Event
166+
{}
167+
// AMP Logic
168+
{}
169+
170+
// Get Input AutoGradMeta
171+
{}
172+
// Forward API Call
173+
{}
174+
// Get Output AutoGradMeta
175+
{}
176+
bool trace_backward = egr::Controller::Instance().HasGrad();
177+
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
178+
// Check Inplace & Bump Inplace Version
179+
{}
180+
{}
181+
// Node Creation
182+
{}
168183
169184
// Returns
170185
return {};
@@ -174,18 +189,8 @@ class {} : public egr::GradNodeBase {{
174189

175190
FORWARD_BODY_TEMPLATE = \
176191
"""
177-
// Get AutoGradMeta
178-
{}
179-
bool trace_backward = egr::Controller::Instance().HasGrad();
180-
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
181-
{}
182-
// Forward API Call
183-
{}
184-
{}
185-
{{
186-
{}
187-
{}
188192
if(require_any_grad) {{
193+
{}
189194
egr::EagerUtils::PassStopGradient({});
190195
191196
// Node Construction
@@ -203,7 +208,6 @@ class {} : public egr::GradNodeBase {{
203208
{}
204209
{}
205210
}}
206-
}}
207211
"""
208212

209213
NAMESPACE_WRAPPER_TEMPLATE = \
@@ -294,7 +298,6 @@ class {} : public egr::GradNodeBase {{
294298

295299
CHECK_INPLACE_TEMPLATE = \
296300
"""
297-
// Check Inplace
298301
egr::EagerUtils::CheckInplace({}, {}, require_any_grad);\n
299302
"""
300303

@@ -625,7 +628,7 @@ def SlotNameMatching(self):
625628
f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}"
626629
)
627630

628-
def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
631+
def GenerateNodeCreationCodes(self):
629632
forward_api_name = self.forward_api_name
630633
forward_inputs_position_map = self.forward_inputs_position_map
631634
forward_outputs_position_map = self.forward_outputs_position_map
@@ -635,67 +638,14 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
635638
backward_grad_outputs_map = self.backward_grad_outputs_map
636639
backward_attrs_list = self.backward_attrs_list
637640
optional_inputs = self.optional_inputs
638-
inplace_map = self.inplace_map if is_inplaced else {}
639641

640-
# Get Input AutoGradMeta
641-
inputs_autograd_meta_list = []
642-
compute_require_grad_args_list = ["trace_backward"]
643-
for name, (ttype, pos) in forward_inputs_position_map.items():
644-
input_autograd_meta_name = GetAutoGradMetaName(name)
645-
if IsPlainTensorType(ttype):
646-
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
647-
else:
648-
assert IsVectorTensorType(ttype)
649-
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
650-
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
651-
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
652-
653-
inputs_autograd_meta_list.append(input_autograd_meta)
654-
compute_require_grad_args_list.append(input_autograd_meta_name)
655-
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
656-
compute_require_grad_args_str = ",".join(compute_require_grad_args_list)
657-
658-
# Get Output AutoGradMeta
659-
outputs_autograd_meta_list = []
642+
# Pass Stop Gradient Args
660643
pass_stop_gradient_args_list = ["false"]
661-
num_fwd_outputs = len(forward_outputs_position_map.keys())
662-
for name, (rtype, pos) in forward_outputs_position_map.items():
644+
for name, (_, _) in forward_outputs_position_map.items():
663645
output_autograd_meta_name = GetAutoGradMetaName(name)
664-
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
665-
if num_fwd_outputs == 1:
666-
if IsPlainTensorType(rtype):
667-
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
668-
else:
669-
assert IsVectorTensorType(rtype)
670-
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
671-
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
672-
else:
673-
# Tuple api_result
674-
if IsPlainTensorType(rtype):
675-
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
676-
else:
677-
assert IsVectorTensorType(rtype)
678-
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
679-
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
680-
681-
outputs_autograd_meta_list.append(output_autograd_meta)
682646
pass_stop_gradient_args_list.append(output_autograd_meta_name)
683-
684-
# ComputeRequireGrad & PassStopGradient
685-
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
686647
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
687648

688-
# Check Inplace
689-
check_inplace_str = ""
690-
bump_inplace_version_str = ""
691-
if is_inplaced:
692-
for inplace_name in inplace_map.keys():
693-
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
694-
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
695-
inplace_name, inplace_autograd_meta_name)
696-
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
697-
inplace_name, inplace_name)
698-
699649
# Node Construction
700650
num_backward_inputs = len(forward_outputs_position_map.keys())
701651
num_backward_outputs = len(forward_inputs_position_map.keys())
@@ -719,6 +669,7 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
719669

720670
# SetTensorWrappers
721671
set_tensor_wrappers_list = []
672+
num_fwd_outputs = len(forward_outputs_position_map.keys())
722673
for name, (atype, is_fwd_input,
723674
pos) in backward_forward_inputs_map.items():
724675
is_optional = (name in optional_inputs)
@@ -794,13 +745,10 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
794745
node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"
795746

796747
self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
797-
inputs_autograd_meta_str, compute_require_grad_args_str,
798-
check_inplace_str, forward_call_str, bump_inplace_version_str,
799-
node_creation_event_str, outputs_autograd_meta_str,
800-
pass_stop_gradient_args_str, node_construction_str,
801-
set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str,
802-
set_edges_str, set_out_rank_str, set_history_str,
803-
set_grad_in_meta_str, set_retain_grad_str)
748+
node_creation_event_str, pass_stop_gradient_args_str,
749+
node_construction_str, set_attributes_str, set_tensor_wrappers_str,
750+
set_grad_out_meta_str, set_edges_str, set_out_rank_str,
751+
set_history_str, set_grad_in_meta_str, set_retain_grad_str)
804752

805753
def run(self):
806754
# Basic Validation Check
@@ -973,7 +921,64 @@ def GenerateForwardDefinition(self, is_inplaced):
973921
returns_str = ", ".join(returns_list)
974922
returns_str = f"std::make_tuple({returns_str})"
975923

976-
self.GenerateNodeCreationCodes(forward_call_str, is_inplaced)
924+
# Node Creation Pre-Processing
925+
# 1. Get Input AutoGradMeta
926+
inputs_autograd_meta_list = []
927+
compute_require_grad_args_list = ["trace_backward"]
928+
for name, (ttype, pos) in forward_inputs_position_map.items():
929+
input_autograd_meta_name = GetAutoGradMetaName(name)
930+
if IsPlainTensorType(ttype):
931+
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
932+
else:
933+
assert IsVectorTensorType(ttype)
934+
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
935+
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
936+
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
937+
938+
inputs_autograd_meta_list.append(input_autograd_meta)
939+
compute_require_grad_args_list.append(input_autograd_meta_name)
940+
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
941+
compute_require_grad_args_str = ",".join(compute_require_grad_args_list)
942+
943+
# 2. Get Output AutoGradMeta
944+
outputs_autograd_meta_list = []
945+
num_fwd_outputs = len(forward_outputs_position_map.keys())
946+
for name, (rtype, pos) in forward_outputs_position_map.items():
947+
output_autograd_meta_name = GetAutoGradMetaName(name)
948+
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
949+
if num_fwd_outputs == 1:
950+
if IsPlainTensorType(rtype):
951+
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
952+
else:
953+
assert IsVectorTensorType(rtype)
954+
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
955+
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
956+
else:
957+
# Tuple api_result
958+
if IsPlainTensorType(rtype):
959+
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
960+
else:
961+
assert IsVectorTensorType(rtype)
962+
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
963+
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
964+
965+
outputs_autograd_meta_list.append(output_autograd_meta)
966+
967+
# 3. ComputeRequireGrad & PassStopGradient
968+
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
969+
970+
# 4. Check Inplace
971+
check_inplace_str = ""
972+
bump_inplace_version_str = ""
973+
if is_inplaced:
974+
for inplace_name in inplace_map.keys():
975+
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
976+
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
977+
inplace_name, inplace_autograd_meta_name)
978+
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
979+
inplace_name, inplace_name)
980+
981+
self.GenerateNodeCreationCodes()
977982

978983
node_creation_str = self.node_creation_str
979984
dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
@@ -1001,7 +1006,10 @@ def GenerateForwardDefinition(self, is_inplaced):
10011006

10021007
self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
10031008
returns_type_str, forward_function_name, inputs_args_definition_str,
1004-
dygraph_event_str, amp_logic_str, node_creation_str, returns_str)
1009+
dygraph_event_str, amp_logic_str, inputs_autograd_meta_str,
1010+
forward_call_str, outputs_autograd_meta_str,
1011+
compute_require_grad_args_str, check_inplace_str,
1012+
bump_inplace_version_str, node_creation_str, returns_str)
10051013
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"
10061014

10071015
logging.info(

0 commit comments

Comments
 (0)