Commit d8a1097

[DoubleGrad PR #8] Enabled triple grads for sigmoid and matmul (PaddlePaddle#41387)
* [Refactor] refactored eager_gen.py PR #2
* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes
* Fixed minor issue
* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition
* Fixed issues
* Supported higher-order grad node generation
* [DoubleGrad PR #4] Supported higher-order GradNode generation
* [DoubleGrad #4] Bug Fixes to Double Grad Node Generation
* Fixed yaml typo
* Fixed yaml typo
* fixed minor issues
* [DoubleGrad PR #5] Enabled gradient computations for grad_tensors passed to paddle.grad()
* Fixed minor issue
* Fixed CI-Inference issue
* Fixed CI-inference issues
* [DoubleGrad PR #7] paddle.grad() to copy backward graph before backward run
* Fixed minor issues
* Fixed issue with backward graph construction logic
* Fixed implementation issues with backward graph reconstruction
* Fixed unittest issue
* Fixed issues
* [DoubleGrad PR #8] Enabled triple grads for sigmoid and matmul
* Fixed issues with phi kernel
* Added triple grad test case
* Fixed minor issue
1 parent 84e8ae7 commit d8a1097
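As a rough usage sketch of what this commit enables (not part of the commit itself; assumes
the new final-state/eager dygraph mode is active, and the tensor shape is illustrative only):

    import paddle
    import paddle.nn.functional as F

    x = paddle.rand([2, 3])
    x.stop_gradient = False

    out = F.sigmoid(x)
    # first-order grad, keeping the graph so it can be differentiated again
    dx = paddle.grad(out, x, create_graph=True)[0]
    # second-order (double) grad
    ddx = paddle.grad(dx, x, create_graph=True)[0]
    # third-order (triple) grad -- the case enabled here for sigmoid and matmul
    dddx = paddle.grad(ddx, x)[0]
    print(dddx.shape)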

File tree

9 files changed: +215 / -37 lines changed

paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py

Lines changed: 4 additions & 8 deletions
@@ -21,8 +21,10 @@
 ########################
 ### Global Variables ###
 ########################
-ops_to_fill_zero_for_empty_grads = set(
-    ["split_grad", "rnn_grad", "matmul_double_grad"])
+ops_to_fill_zero_for_empty_grads = set([
+    "split_grad", "rnn_grad", "matmul_double_grad", "matmul_triple_grad",
+    "sigmoid_triple_grad"
+])
 
 # For API dispatch used at python-level
 # { op_name : [arg_name, ...] }
@@ -171,12 +173,6 @@ def GetForwardFunctionName(string):
     return f"{string}_final_state_dygraph_function"
 
 
-def TransformGradVarNameForDoubleGradGeneration(string):
-    if IsGradName(string):
-        string = "grad_" + string[:-5]
-    return string
-
-
 def GetIndent(num):
     tab = " "
     return "".join([tab for i in range(num)])
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py

Lines changed: 39 additions & 21 deletions
@@ -31,7 +31,6 @@
 from codegen_utils import ParseYamlForward, ParseYamlBackward
 from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
 from codegen_utils import ops_to_fill_zero_for_empty_grads
-from codegen_utils import TransformGradVarNameForDoubleGradGeneration
 from codegen_utils import AssertMessage, GetIndent
 
 
@@ -483,10 +482,8 @@ def ForwardsValidationCheck(self):
         orig_forward_returns_list = self.orig_forward_returns_list
 
         for i in range(len(forward_inputs_list)):
-            forward_input_name = forward_inputs_list[i][0]
             forward_input_type = forward_inputs_list[i][1]
             forward_input_pos = forward_inputs_list[i][2]
-            orig_input_name = orig_forward_inputs_list[i][0]
             orig_input_type = orig_forward_inputs_list[i][1]
             orig_input_pos = orig_forward_inputs_list[i][2]
 
@@ -496,11 +493,9 @@ def ForwardsValidationCheck(self):
                 forward_input_pos, orig_input_pos)
 
         for i in range(len(forward_attrs_list)):
-            orig_attr_name = orig_forward_attrs_list[i][0]
             orig_attr_type = orig_forward_attrs_list[i][1]
             orig_attr_default = orig_forward_attrs_list[i][2]
             orig_attr_pos = orig_forward_attrs_list[i][3]
-            forward_attr_name = forward_attrs_list[i][0]
             forward_attr_type = forward_attrs_list[i][1]
             forward_attr_default = forward_attrs_list[i][2]
             forward_attr_pos = forward_attrs_list[i][3]
@@ -1133,11 +1128,20 @@ def __init__(self,
         DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
                                               grad_api_contents, namespace)
 
+        # Record name mapping from forward_api_name to grad_api_names
+        self.to_next_grad_name_mapping = {}  # {name : name}
+
         # Generated Results
         self.node_declaration_str = ""
         self.node_definition_str = ""
         self.next_grad_api_contents = next_grad_api_contents
 
+    def TransformToNextGradName(self, string):
+        name_mapping = self.to_next_grad_name_mapping
+        if string in name_mapping.keys():
+            return name_mapping[string]
+        return string
+
     def ResetOptionalInputs(self):
         namespace = self.namespace
         grad_api_contents = self.grad_api_contents
@@ -1147,6 +1151,22 @@ def ResetOptionalInputs(self):
 
         self.optional_inputs = base_generator.optional_inputs
 
+    def RecordGrad2NextGradNameMapping(self, next_node_generator):
+        next_orig_inputs_list = next_node_generator.orig_forward_inputs_list
+        next_orig_returns_list = next_node_generator.orig_forward_returns_list
+
+        next_forward_inputs_list = next_node_generator.forward_inputs_list
+        next_forward_returns_list = next_node_generator.forward_returns_list
+        for i in range(len(next_orig_inputs_list)):
+            grad_name = next_orig_inputs_list[i][0]
+            next_forward_name = next_forward_inputs_list[i][0]
+            self.to_next_grad_name_mapping[grad_name] = next_forward_name
+
+        for i in range(len(next_orig_returns_list)):
+            grad_ret_name = next_orig_returns_list[i][0]
+            next_ret_name = next_forward_returns_list[i][0]
+            self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name
+
     def GenerateHigherOrderNodeCreationCode(self):
         namespace = self.namespace
         grad_api_contents = self.grad_api_contents
@@ -1164,6 +1184,8 @@ def GenerateHigherOrderNodeCreationCode(self):
             next_node_generator.GenerateNodeCreationCodes()
             grad_node_creation_str = next_node_generator.node_creation_str
 
+            self.RecordGrad2NextGradNameMapping(next_node_generator)
+
         return grad_node_creation_str
 
     def GenerateNodeDeclaration(self):
@@ -1253,8 +1275,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
         for name, (_, is_fwd_input,
                    grad_api_position), in backward_forward_inputs_map.items():
             tensor_wrapper_name = GetSavedName(name)
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             is_optional = (name in self.optional_inputs)
             tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());"
@@ -1274,8 +1295,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
         # Grad Ins from grads
         for name, (ttype, fwd_position,
                    grad_api_position) in backward_grad_inputs_map.items():
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             is_optional = (name in self.optional_inputs)
             if IsPlainTensorType(ttype):
@@ -1316,8 +1336,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
         num_outputs = len(backward_grad_outputs_map.keys())
         for name, (ttype, fwd_position,
                    grad_api_position) in backward_grad_outputs_map.items():
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             if num_outputs == 1:
                 get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result;"
@@ -1339,8 +1358,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
         compute_require_grad_args_list = ["trace_backward"]
         for name, (ttype, pos,
                    grad_api_position) in backward_grad_inputs_map.items():
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             input_autograd_meta_name = GetAutoGradMetaName(
                 transformed_tensor_name)
@@ -1358,8 +1376,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
 
         # 2. Get TensorWrapper AutoGradMeta
         for name, (ttype, _, pos), in backward_forward_inputs_map.items():
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             input_autograd_meta_name = GetAutoGradMetaName(
                 transformed_tensor_name)
@@ -1382,8 +1399,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
         outputs_autograd_meta_list = []
         num_fwd_outputs = len(backward_grad_outputs_map.keys())
         for name, (rtype, pos, _) in backward_grad_outputs_map.items():
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             output_autograd_meta_name = GetAutoGradMetaName(
                 transformed_tensor_name)
@@ -1417,8 +1433,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
         returns_str = f"{indent}std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
         for name, (ttype, fwd_position,
                    grad_api_position) in backward_grad_outputs_map.items():
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             # Infer Grad API Return Type
             if num_bwd_outputs == 1:
@@ -1441,6 +1456,9 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
 
         grad_node_name = GetGradNodeName(forward_api_name)
 
+        if len(grad_node_creation_str) == 0:
+            grad_node_creation_str = f"if(create_graph) VLOG(3) << \"Higher order grad node for {grad_node_name} has not been implemented yet.\";"
+
         self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
             grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
             grad_function_call_str, get_outputs_str, inputs_autograd_meta_str,
@@ -1457,11 +1475,11 @@ def run(self):
         #####################
         ## Code Generation ##
         #####################
-        self.GenerateNodeDeclaration()
-
         # Higher-order GradNode generation
         grad_node_creation_str = self.GenerateHigherOrderNodeCreationCode()
 
+        self.GenerateNodeDeclaration()
+
         self.GenerateNodeDefinition(grad_node_creation_str)
 
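To make the intent of the refactor concrete, a hypothetical sketch (names invented for
illustration, not taken from the generator) of the behavioural difference: the removed
TransformGradVarNameForDoubleGradGeneration rewrote grad variable names by string convention
alone ("x_grad" becomes "grad_x"), whereas TransformToNextGradName looks the name up in the
mapping recorded by RecordGrad2NextGradNameMapping from the next-order grad op's signature,
and falls back to the original name when no mapping exists:

    # hypothetical mapping, as RecordGrad2NextGradNameMapping might record it
    to_next_grad_name_mapping = {"x_grad": "grad_x", "out_grad": "grad_out"}

    def transform_to_next_grad_name(name):
        # unmapped names (e.g. plain forward inputs) pass through unchanged
        return to_next_grad_name_mapping.get(name, name)

    assert transform_to_next_grad_name("x_grad") == "grad_x"
    assert transform_to_next_grad_name("bias") == "bias"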

paddle/phi/infermeta/backward.cc

Lines changed: 48 additions & 0 deletions
@@ -206,6 +206,54 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
     dz->share_meta(z);
   }
 }
+void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
+                                    const MetaTensor& y,
+                                    const MetaTensor& z,
+                                    const MetaTensor& k,
+                                    MetaTensor* dx,
+                                    MetaTensor* dy,
+                                    MetaTensor* dz,
+                                    MetaTensor* dk) {
+  if (dx) {
+    dx->share_meta(x);
+  }
+  if (dy) {
+    dy->share_meta(y);
+  }
+  if (dz) {
+    dz->share_meta(z);
+  }
+  if (dk) {
+    dk->share_meta(k);
+  }
+}
+
+void GeneralQuinaryGradInferMeta(const MetaTensor& x,
+                                 const MetaTensor& y,
+                                 const MetaTensor& z,
+                                 const MetaTensor& k,
+                                 const MetaTensor& l,
+                                 MetaTensor* dx,
+                                 MetaTensor* dy,
+                                 MetaTensor* dz,
+                                 MetaTensor* dk,
+                                 MetaTensor* dl) {
+  if (dx) {
+    dx->share_meta(x);
+  }
+  if (dy) {
+    dy->share_meta(y);
+  }
+  if (dz) {
+    dz->share_meta(z);
+  }
+  if (dk) {
+    dk->share_meta(k);
+  }
+  if (dl) {
+    dl->share_meta(l);
+  }
+}
 
 void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) {
   if (dx) {
paddle/phi/infermeta/backward.h

Lines changed: 20 additions & 0 deletions
@@ -96,6 +96,26 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
                                  MetaTensor* dy,
                                  MetaTensor* dz);
 
+void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
+                                    const MetaTensor& y,
+                                    const MetaTensor& z,
+                                    const MetaTensor& k,
+                                    MetaTensor* dx,
+                                    MetaTensor* dy,
+                                    MetaTensor* dz,
+                                    MetaTensor* dk);
+
+void GeneralQuinaryGradInferMeta(const MetaTensor& x,
+                                 const MetaTensor& y,
+                                 const MetaTensor& z,
+                                 const MetaTensor& k,
+                                 const MetaTensor& l,
+                                 MetaTensor* dx,
+                                 MetaTensor* dy,
+                                 MetaTensor* dz,
+                                 MetaTensor* dk,
+                                 MetaTensor* dl);
+
 void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx);
 
 void GumbelSoftmaxGradInferMeta(const MetaTensor& out,

paddle/phi/kernels/activation_grad_kernel.h

Lines changed: 3 additions & 3 deletions
@@ -125,18 +125,18 @@ void EluDoubleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void SigmoidDoubleGradKernel(const Context& dev_ctx,
                              const DenseTensor& out,
-                             const DenseTensor& ddx,
                              const DenseTensor& dout,
+                             const DenseTensor& ddx,
                              DenseTensor* dout_new,
                              DenseTensor* ddout);
 
 template <typename T, typename Context>
 void SigmoidTripleGradKernel(const Context& dev_ctx,
                              const DenseTensor& out,
-                             const DenseTensor& ddx,
                              const DenseTensor& dout,
-                             const DenseTensor& d_ddout,
+                             const DenseTensor& ddx,
                              const DenseTensor& d_dout_new,
+                             const DenseTensor& d_ddout,
                              DenseTensor* d_out_new,
                              DenseTensor* d_dout,
                              DenseTensor* d_ddx);

paddle/phi/kernels/impl/activation_grad_impl.h

Lines changed: 3 additions & 3 deletions
@@ -243,8 +243,8 @@ void LogitGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void SigmoidDoubleGradKernel(const Context& dev_ctx,
                              const DenseTensor& out,
-                             const DenseTensor& ddx,
                              const DenseTensor& dout,
+                             const DenseTensor& ddx,
                              DenseTensor* dout_new,
                              DenseTensor* ddout) {
   if (dout_new) {
@@ -262,10 +262,10 @@ void SigmoidDoubleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void SigmoidTripleGradKernel(const Context& dev_ctx,
                              const DenseTensor& out,
-                             const DenseTensor& ddx,
                              const DenseTensor& dout,
-                             const DenseTensor& d_ddout,
+                             const DenseTensor& ddx,
                              const DenseTensor& d_dout_new,
+                             const DenseTensor& d_ddout,
                              DenseTensor* d_out_new,
                              DenseTensor* d_dout,
                              DenseTensor* d_ddx) {

paddle/phi/ops/compat/activation_sig.cc

Lines changed: 2 additions & 2 deletions
@@ -139,13 +139,13 @@ KernelSignature TanhTripleGradOpArgumentMapping(
 KernelSignature SigmoidDoubleGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "sigmoid_double_grad", {"Out", "DDX", "DOut"}, {}, {"DOutNew", "DDOut"});
+      "sigmoid_double_grad", {"Out", "DOut", "DDX"}, {}, {"DOutNew", "DDOut"});
 }
 
 KernelSignature SigmoidTripleGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature("sigmoid_triple_grad",
-                         {"Out", "DDX", "DOut", "D_DDOut", "D_DOut_New"},
+                         {"Out", "DOut", "DDX", "D_DOut_New", "D_DDOut"},
                          {},
                          {"D_OutNew", "D_DOut", "D_DDx"});
 }
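The reordering above mirrors the parameter reorder in activation_grad_kernel.h and
activation_grad_impl.h: the fluid-op input names in the compat signature must line up
positionally with the phi kernel's tensor parameters. A throwaway Python check of that
correspondence (illustrative only, not part of the commit):

    kernel_params = ["out", "dout", "ddx", "d_dout_new", "d_ddout"]   # SigmoidTripleGradKernel
    compat_inputs = ["Out", "DOut", "DDX", "D_DOut_New", "D_DDOut"]   # sigmoid_triple_grad mapping
    normalize = lambda s: s.replace("_", "").lower()
    assert [normalize(n) for n in compat_inputs] == [normalize(n) for n in kernel_params]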
