Commit 82630f3

[Dy2stat] Add Support for paddle.grad (#33110)
This PR makes the following changes to support double grad:
1. Translate `paddle.grad` to `paddle.static.gradients` so that dy2stat supports double grad.
2. Fix an IfElseTransformer bug where a value may not be changed if a "Store before Load" variable has its "Store" statement inside an IfElse conditional statement.
3. Add `DOut` to `run_program_op` to support double-grad variables.
4. Add support for renaming double-grad variables for `jit.save/load`.
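
A minimal dygraph sketch of the pattern these changes enable (not part of this commit; Paddle 2.x API assumed, names illustrative): inside a `@paddle.jit.to_static` function, `paddle.grad` is rewritten to `paddle.static.gradients`, so the returned gradient stays differentiable and a later `backward()` can compute the double grad.

import paddle

@paddle.jit.to_static
def double_grad_example(x):
    y = x * x
    # Under dy2stat this call is translated to paddle.static.gradients by the
    # new GradTransformer; create_graph has no static counterpart and is dropped.
    dx = paddle.grad(outputs=[y], inputs=[x], create_graph=True)[0]
    return dx * dx

x = paddle.to_tensor([2.0], stop_gradient=False)
out = double_grad_example(x)  # runs the converted static program (run_program op)
out.backward()                # differentiating through dx exercises double grad
print(x.grad)
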
1 parent 1e9299a commit 82630f3

9 files changed

Lines changed: 341 additions & 26 deletions

paddle/fluid/operators/run_program_op.cc

Lines changed: 8 additions & 0 deletions
@@ -83,6 +83,13 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
               "contains at most one scope."
               "NOTE: Do not use Scope directly because Scope output is not "
               "currently supported.");
+    AddOutput("DOut",
+              "(vector<LoDTensor>)"
+              "The output tensors for GRAD Tensors in RunProgram forward "
+              "operator, the forward operator contains GRAD Tensors when it "
+              "computes double grad.")
+        .AsDuplicable()
+        .AsDispensable();
     AddAttr<BlockDesc*>("global_block",
                         "(BlockDesc *)"
                         "The global block of executed program desc.");
@@ -154,6 +161,7 @@ class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> {
     grad_op->SetInput("Params", this->Input("Params"));
     grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     grad_op->SetInput("OutScope", this->Output("OutScope"));
+    grad_op->SetInput("DOut", this->Output("DOut"));
     grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     grad_op->SetOutput(framework::GradVarName("Params"),
                        this->InputGrad("Params"));

paddle/fluid/operators/run_program_op.h

Lines changed: 10 additions & 4 deletions
@@ -131,6 +131,9 @@ static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
                                const std::vector<std::string> &var_names,
                                framework::Scope *scope) {
   for (size_t i = 0; i < vars.size(); ++i) {
+    if (var_names[i] == "Fake_var") {
+      continue;
+    }
     auto *var = scope->Var(var_names[i]);
     CheckInputVarStatus(*vars[i], var_names[i]);
     VariableShare(*vars[i], var);
@@ -141,9 +144,9 @@ static void ShareVarsFromScope(const std::vector<Variable *> &vars,
                                const std::vector<std::string> &var_names,
                                framework::Scope *scope) {
   for (size_t i = 0; i < vars.size(); ++i) {
-    if (var_names[i] == framework::kEmptyVarName) {
-      VLOG(2) << "find variable name is " << framework::kEmptyVarName
-              << ", skip it!";
+    if (var_names[i] == framework::kEmptyVarName ||
+        var_names[i] == "Fake_var") {
+      VLOG(2) << "find variable name is " << var_names[i] << ", skip it!";
       continue;
     }
     // NOTE: Here skip not found var is dangerous, if a bug is caused here,
@@ -170,9 +173,11 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
     auto &input_vars = ctx.MultiInputVar("X");
     auto &param_vars = ctx.MultiInputVar("Params");
     auto output_vars = ctx.MultiOutputVar("Out");
+    auto dout_vars = ctx.MultiOutputVar("DOut");

     auto input_var_names = ctx.InputNames("X");
     auto output_var_names = ctx.OutputNames("Out");
+    auto dout_var_names = ctx.OutputNames("DOut");

     // current program may not hold parameters
     std::vector<std::string> param_names;
@@ -195,7 +200,7 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
     // Step 2. prepare executor and init persistable variables
     framework::Executor exe(ctx.GetPlace());
     auto exe_ctx = framework::GetExecutorInfoFromCache(
-        exe, ctx, {output_var_names}, /*is_grad=*/false);
+        exe, ctx, {output_var_names, dout_var_names}, /*is_grad=*/false);

     // NOTE(Aurelius84): While training some models, forward can be called many
     // times and then apply backpropagation all at once, such as Reinforcement
@@ -219,6 +224,7 @@ class RunProgramOpKernel : public framework::OpKernel<T> {

     // Step 4. Get Output
     details::ShareVarsFromScope(output_vars, output_var_names, &scope);
+    details::ShareVarsFromScope(dout_vars, dout_var_names, &scope);

     // Debug info: scope info when run end
     VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front());

python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@
 from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakTransformOptimizer
 from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
 from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer
+from paddle.fluid.dygraph.dygraph_to_static.grad_transformer import GradTransformer
 from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
 from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
 from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer
@@ -86,6 +87,7 @@ def transfer_from_node_type(self, node_wrapper):
             PrintTransformer,  # print statement
             CallTransformer,  # transform call recursively
             CastTransformer,  # type casting statement
+            GradTransformer,  # transform paddle.grad to paddle.gradients
         ]

         for index, transformer in enumerate(transformers):

python/paddle/fluid/dygraph/dygraph_to_static/grad_transformer.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import gast
+import warnings
+
+from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
+from paddle.fluid.dygraph.dygraph_to_static import utils
+
+
+class GradTransformer(gast.NodeTransformer):
+    """
+    A class transforms dygraph paddle.grad to static graph paddle.gradients. The
+    transformation is applied to support double grad mode.
+    """
+
+    def __init__(self, wrapper_root):
+        assert isinstance(
+            wrapper_root, AstNodeWrapper
+        ), "Input non-AstNodeWrapper node for the initialization of GradTransformer."
+        self.wrapper_root = wrapper_root
+        self.root = wrapper_root.node
+
+    def transform(self):
+        self.visit(self.root)
+
+    def visit_Call(self, node):
+        self.generic_visit(node)
+        if not is_grad_api_node(node):
+            return node
+
+        dygraph_grad_parameters = [
+            "outputs", "inputs", "grad_outputs", "retain_graph", "create_graph",
+            "only_inputs", "allow_unused", "no_grad_vars"
+        ]
+        to_static_grad_param = {
+            "outputs": "targets",
+            "inputs": "inputs",
+            "grad_outputs": "target_gradients",
+            "no_grad_vars": "no_grad_set"
+        }
+        static_keywords = []
+
+        for kw in node.keywords:
+            if kw.arg not in dygraph_grad_parameters or kw.arg not in to_static_grad_param:
+                warnings.warn("paddle.grad has unsupported parameter in jit: " +
+                              kw.arg + ", jit will discard it")
+                continue
+            dygraph_grad_parameters.remove(kw.arg)
+            kw.arg = to_static_grad_param[kw.arg]
+            static_keywords.append(kw)
+
+        for i in range(len(node.args)):
+            arg_name = dygraph_grad_parameters[i]
+            if arg_name not in to_static_grad_param:
+                warnings.warn("paddle.grad has unsupported parameter in jit: " +
+                              kw.arg + ", jit will discard it")
+                continue
+            kw = gast.keyword(
+                arg=to_static_grad_param[arg_name], value=node.args[i])
+            static_keywords.append(kw)
+
+        node.func = gast.parse('paddle.static.gradients').body[0].value
+        node.keywords = static_keywords
+        node.args = []
+        return node
+
+
+def is_grad_api_node(node):
+    assert isinstance(node, gast.Call)
+    api_name = utils.ast_to_source_code(node.func).strip()
+    if utils.is_paddle_api(node):
+        return api_name.endswith("grad")
+    return False
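
As a plain-Python aside (not from the commit; the helper name map_grad_kwargs is hypothetical), the argument renaming that visit_Call performs can be sketched as:

to_static_grad_param = {
    "outputs": "targets",
    "inputs": "inputs",
    "grad_outputs": "target_gradients",
    "no_grad_vars": "no_grad_set",
}

def map_grad_kwargs(dygraph_kwargs):
    # Rename paddle.grad keyword arguments to their paddle.static.gradients
    # counterparts; arguments without a static counterpart (e.g. create_graph)
    # are dropped, mirroring the warning-and-discard behavior above.
    return {
        to_static_grad_param[name]: value
        for name, value in dygraph_kwargs.items() if name in to_static_grad_param
    }

# paddle.grad(outputs=y, inputs=x, create_graph=True) becomes
# paddle.static.gradients(targets=y, inputs=x):
print(map_grad_kwargs({"outputs": "y", "inputs": "x", "create_graph": True}))
# -> {'targets': 'y', 'inputs': 'x'}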

python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py

Lines changed: 6 additions & 4 deletions
@@ -402,7 +402,7 @@ def _modified_vars(child_dict, parent_dict):
         var for var in _vars_with_store(child_dict) if var in parent_dict
     ])

-    def _vars_loaded_before_store(ids_dict):
+    def _vars_loaded(ids_dict):
         """
         gast.Param is also a kind of `load` semantic.
         """
@@ -411,8 +411,6 @@ def _vars_loaded_before_store(ids_dict):
             for ctx in ctxs:
                 if isinstance(ctx, (gast.Load, gast.Param)):
                     new_dict[k].append(ctx)
-                elif isinstance(ctx, gast.Store):
-                    break
         return new_dict

     # modified vars
@@ -439,8 +437,12 @@ def _vars_loaded_before_store(ids_dict):
     new_vars_in_body_and_orelse = body_new_vars & orelse_new_vars

     # 3. new var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
+    # TODO(zhhsplendid): the _vars_loaded can be optimized as _vars_loaded_before_store. Because if a variable is stored before load,
+    # the value would change by the store statement, we don't have to return to change the value. However, analysis is
+    # complex because if the IfElse is nested and outer IfElse store statement may not run at all. We will put this optimization
+    # as the future TODO
     used_vars_after_ifelse = set(
-        [var for var in _vars_loaded_before_store(after_ifelse_vars_dict)])
+        [var for var in _vars_loaded(after_ifelse_vars_dict)])
     new_vars_to_create = new_vars_in_one_of_body_or_orelse & used_vars_after_ifelse | new_vars_in_body_and_orelse

     # 4. generate return_ids of if/else node.
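
An illustrative (hypothetical) dygraph pattern of the kind this fix targets: `out` is created in only one branch, and its first appearance after that if statement is a Store inside another conditional, so the old `_vars_loaded_before_store` analysis could decide the rewritten if/else need not return it.

def func(x):
    if x > 0:
        out = x + 1    # `out` is created only in the true branch
    if x > 10:
        out = 0.0      # a Store before any Load, but it runs only conditionally
    return out         # the Load that actually needs the value from the first if

# With _vars_loaded_before_store, `out` hit the Store first and was dropped from
# used_vars_after_ifelse; counting every Load (_vars_loaded) keeps it in
# new_vars_to_create so the converted first if/else returns it.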

python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py

Lines changed: 26 additions & 5 deletions
@@ -135,6 +135,7 @@ def __init__(self, main_program, inputs, outputs, parameters=None):
         self._origin_main_program = self._verify_program(main_program)
         self._inner_scope = core.Scope()
         # Set default mode to train
+        self._double_grads = self._get_double_grads(self._origin_main_program)
         self.training = True

     @LazyInitialized
@@ -192,24 +193,44 @@ def _prune_unused_params(self, program):
         """
         required_params = []
         for param in self._params:
+            found_param = False
             for block in program.blocks:
-                if param.name in block.vars:
-                    required_params.append(param)
+                for op in block.ops:
+                    if param.name in op.input_arg_names or param.name in op.output_arg_names:
+                        required_params.append(param)
+                        found_param = True
+                        break
+                if found_param:
                     break

         self._params = required_params

+    def _get_double_grads(self, program):
+        double_grads = []
+        for block in program.blocks:
+            for name in block.vars:
+                if "@GRAD" in name:
+                    var_desc = block.vars[name].desc
+                    var_base = core.VarBase(var_desc.dtype(),
+                                            var_desc.shape(),
+                                            var_desc.name(),
+                                            var_desc.type(), False)
+                    double_grads.append(var_base)
+        return double_grads
+
     def forward(self, inputs):
         in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
-
         framework._dygraph_tracer().trace_op(
             type='run_program',
             inputs={
                 'X': valid_vars(in_vars),
                 'Params': valid_vars(self._params)
             },
-            outputs={'Out': valid_vars(out_vars),
-                     'OutScope': tmp_scope_vec},
+            outputs={
+                'Out': valid_vars(out_vars),
+                'OutScope': tmp_scope_vec,
+                'DOut': valid_vars(self._double_grads)
+            },
             attrs={
                 'global_block': self.program.desc.block(0),
                 'start_op_index': 0,
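
A hedged sketch (not part of the commit; names illustrative) of why `_get_double_grads` scans block variables for "@GRAD": building a static program that calls `paddle.static.gradients` creates gradient variables whose names carry the "@GRAD" suffix, and those are the tensors now surfaced through the run_program op's `DOut` output so they stay reachable for the second backward pass.

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[1], dtype='float32')
    x.stop_gradient = False
    y = x * x
    dx = paddle.static.gradients(targets=[y], inputs=[x])[0]

# The same kind of scan _get_double_grads performs over the program's blocks:
grad_var_names = [
    name for block in main_prog.blocks for name in block.vars if "@GRAD" in name
]
print(grad_var_names)  # expect names such as 'x@GRAD'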
