18 changes: 16 additions & 2 deletions python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py
@@ -103,8 +103,11 @@ def _replace_value_with_input_spec(self, args):
for idx, input_var in enumerate(flatten(args)):
if isinstance(input_var, np.ndarray):
input_var = paddle.static.InputSpec.from_numpy(input_var)
_set_spec_stop_gradient(input_var, True)
elif isinstance(input_var, core.VarBase):
stop_gradient = input_var.stop_gradient
input_var = paddle.static.InputSpec.from_tensor(input_var)
_set_spec_stop_gradient(input_var, stop_gradient)

args_with_spec.append(input_var)

@@ -172,13 +175,15 @@ def to_static_inputs_with_spec(self, input_with_spec, main_program):
block = main_program.global_block()
for i, var_spec in enumerate(flat_input_spec):
if isinstance(var_spec, paddle.static.InputSpec):
stop_gradient = getattr(var_spec, 'stop_gradient', False)
feed_layer = block.create_var(
# TODO(Aurelius84): consider a more elegant way to name this
name=var_spec.name or "feed_%s" % i,
shape=var_spec.shape,
dtype=var_spec.dtype,
is_data=True,
need_check_feed=False)
need_check_feed=False,
stop_gradient=stop_gradient)
else:
feed_layer = var_spec
inputs.append(feed_layer)
@@ -302,7 +307,7 @@ def check_type_and_len(input, spec, check_length=False):
if isinstance(rest_input, (core.VarBase, np.ndarray)):
logging_utils.warn(
"The inputs constain `{}` without specificing InputSpec, its shape and dtype will be treated immutable. "
"Please specific InputSpec information in `@declarative` if you expect them as mutable inputs.".
"Please specific InputSpec information in `@to_static` if you expect them as mutable inputs.".
format(type_name(rest_input)))
input_with_spec.extend(inputs[len(input_spec):])

@@ -380,3 +385,12 @@ def _replace_spec_name(name, input_spec):
return processed_specs
else:
return input_spec


def _set_spec_stop_gradient(spec, stop_gradient):
"""
Set new attribute ``stop_gradient`` for InputSpec to avoid generating redundant grad_op
while append_backward.
"""
assert isinstance(spec, paddle.static.InputSpec)
spec.stop_gradient = stop_gradient
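
For context, here is a minimal sketch of how the carried-over flag is meant to flow, using only the public `paddle.static.InputSpec` API; the tensor shape and variable names are illustrative assumptions, not part of this patch:

import numpy as np
import paddle

x = paddle.to_tensor(np.ones([4, 8], dtype='float32'))
x.stop_gradient = False  # a trainable input

# Mirror of what _replace_value_with_input_spec now does: build the spec
# from the tensor, then carry the tensor's stop_gradient flag onto it so
# that to_static_inputs_with_spec creates the feed var with the same flag
# and append_backward skips the redundant grad ops.
spec = paddle.static.InputSpec.from_tensor(x)
spec.stop_gradient = x.stop_gradient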
77 changes: 43 additions & 34 deletions python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
@@ -35,6 +35,7 @@ class NestSequence(object):

def __init__(self, raw_input, need_check=False):
self.__raw_input = raw_input
self.__input_list = self.tolist()
self.__var_ids = self._get_var_ids()
self._check_non_variable(need_check)

@@ -48,12 +49,12 @@ def restore(self, value_list):
"""
Restores the nested sequence from value list.
"""
assert len(self.tolist()) == len(value_list)
assert len(self.__input_list) == len(value_list)
return pack_sequence_as(self.__raw_input, value_list)

def _get_var_ids(self):
var_ids = []
for idx, var in enumerate(self.tolist()):
for idx, var in enumerate(self.__input_list):
if isinstance(var, (framework.Variable, core.VarBase)):
var_ids.append(idx)

@@ -65,7 +66,7 @@ def _check_non_variable(self, need_check):
"""
if need_check:
warning_types = set()
for var in self.tolist():
for var in self.__input_list:
if not isinstance(var, (framework.Variable, core.VarBase)):
warning_types.add(type(var))
if warning_types:
@@ -80,7 +81,7 @@ def var_ids(self):
return self.__var_ids

def __getitem__(self, item):
return self.tolist()[item]
return self.__input_list[item]
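
The NestSequence edits above replace repeated self.tolist() calls with a __input_list computed once in __init__; since the raw input is never mutated after construction, flattening it a single time is safe. A hedged, self-contained sketch of that caching idea (class and method names are illustrative):

class NestSeqSketch(object):
    def __init__(self, raw_input):
        self.__raw_input = raw_input
        # Flatten once up front; every later lookup reuses the cache.
        self.__input_list = self._flatten(raw_input)

    def _flatten(self, value):
        # Minimal stand-in for Paddle's flatten() utility.
        if isinstance(value, (list, tuple)):
            return [leaf for item in value for leaf in self._flatten(item)]
        return [value]

    def __getitem__(self, item):
        return self.__input_list[item]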


class LazyInitialized(object):
@@ -106,7 +107,7 @@ def _change_is_test_status(program, is_test):
return program


class PartialProgramLayer(layers.Layer):
class PartialProgramLayer:
"""
PartialProgramLayer wraps all the ops from layers decorated by `@declarative`
and execute them as a static subgraph.
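
As a hedged, user-facing illustration of that path (public API only; the network below is an assumed example, not taken from this patch), a method decorated with paddle.jit.to_static is what ultimately executes through a PartialProgramLayer:

import paddle

class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = paddle.nn.Linear(8, 2)

    @paddle.jit.to_static
    def forward(self, x):
        # Ops traced here run as a static subgraph under the hood.
        return self.linear(x)

net = SimpleNet()
out = net(paddle.randn([4, 8]))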
@@ -134,7 +135,9 @@ def __init__(self, main_program, inputs, outputs, parameters=None):
self._params = parameters if parameters is not None else []

self._origin_main_program = self._verify_program(main_program)
self._inner_scope = core.Scope()
self._tmp_scope_vec = self._create_scope_vec()
# A fake_var to handle empty input or output
self.__fake_vars = _create_fake_var()
# Set default mode to train
self._double_grads = self._get_double_grads(self._origin_main_program)
self.training = True
@@ -217,19 +220,19 @@ def _get_double_grads(self, program):
var_desc.name(),
var_desc.type(), False)
double_grads.append(var_base)
return double_grads
return self._valid_vars(double_grads)

def forward(self, inputs):
in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
def __call__(self, inputs):
in_vars, out_vars = self._prepare(inputs)

attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
0, 'end_op_index', self._infer_program.desc.block(0).op_size(),
'is_test', not self.training)
core.ops.run_program(
valid_vars(in_vars),
valid_vars(self._params),
valid_vars(out_vars), tmp_scope_vec,
valid_vars(self._double_grads), *attrs)
self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars), self._tmp_scope_vec, self._double_grads,
*attrs)

restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
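
On the forward-to-__call__ change: with layers.Layer removed from the bases, instances are now invoked directly through __call__ rather than Layer's hook-dispatching machinery. A hedged, framework-free sketch of the resulting calling convention (all names are illustrative):

class PlainProgramRunner(object):
    """Plain callable: no Layer pre/post-forward hooks fire on call."""

    def __init__(self, run_fn):
        # Any callable that executes the cached static program would do.
        self._run_fn = run_fn

    def __call__(self, inputs):
        # Direct dispatch: prepare -> run -> restore, nothing in between.
        return self._run_fn(inputs)

runner = PlainProgramRunner(lambda xs: [x * 2 for x in xs])
print(runner([1, 2, 3]))  # [2, 4, 6]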
@@ -264,33 +267,36 @@ def _prepare(self, inputs):
expected_place):
var = value._copy_to(expected_place, False)
var.stop_gradient = True
var.name = value.name
else:
var = value
var.name = self._inputs[i].desc.name()
else:
continue
input_vars.append(var)

# Create VarBase to receive output data.
out_vars = []
for idx in self._outputs.var_ids:
var = self._outputs[idx]
def create_out(var_id):
var = self._outputs[var_id]
assert isinstance(var, framework.Variable)
var_desc = var.desc
var_base = core.VarBase(var_desc.dtype(),
var_desc.shape(),
var_desc.name(), var_desc.type(), False)
out_vars.append(var_base)
return var_base

# Create VarBase to receive output data.
out_vars = list(map(create_out, self._outputs.var_ids))

return input_vars, out_vars

def _create_scope_vec(self):
# Hold forward variables
tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
"program_out_scope",
core.VarDesc.VarType.STEP_SCOPES, True)

tmp_scope_vec.value().set_scope(self._inner_scope)

return input_vars, out_vars, tmp_scope_vec
inner_scope = core.Scope()
tmp_scope_vec.value().set_scope(inner_scope)
return tmp_scope_vec

def _restore_out(self, out_vars):
"""
@@ -311,8 +317,9 @@ def _clone_for_test(self, main_program):
return main_program.clone(for_test=True)

def _is_no_value(self, var):
if isinstance(var, core.VarBase):
if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
if isinstance(var, core.VarBase) and var.shape == [1]:
# NOTE: .numpy() will insert MemcpySync operation, it hits performance.
if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
return True
return False
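
The NOTE above motivates the reshuffled condition: a tensor's shape is host-side metadata, while .numpy() forces a device-to-host copy (the MemcpySync the comment refers to). A generic, hedged illustration of the same short-circuit pattern, assuming any tensor-like object exposing .shape and .numpy():

def is_magic_scalar(tensor, magic_value):
    # Cheap metadata check first; reading the shape needs no device sync.
    if list(tensor.shape) != [1]:
        return False
    # Only pay for the device-to-host copy once the shape matches.
    return tensor.numpy()[0] == magic_value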

@@ -405,20 +412,22 @@ def _check_params_all_inited(self, main_program):
"Please define the layer with parameters in `__init__` function."
% name)

def _valid_vars(self, vars):
"""
Note: run_program_op.InferShape requires `X`/'Out' not be null.
But it's common in dy2static, fake varBase is created to handle the
problem.
"""
return vars if vars else self.__fake_vars


def valid_vars(vars):
def _create_fake_var():
"""
Note: run_program_op.InferShape requires `X`/'Out' not be null.
But it's common in dy2static, fake varBase is created to handle the
problem.
Create a fake_var (force on CPU) to handle empty input or output
"""
if vars:
return vars
return [
core.VarBase(
value=[1],
name='Fake_var',
place=framework._current_expected_place())
core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var",
core.VarDesc.VarType.RAW, False)
]
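
Taken together, the two docstrings describe a small sentinel pattern: run_program_op.InferShape rejects empty X/Out lists, so a placeholder variable is now created once in __init__ and substituted whenever a list is empty, instead of being rebuilt on every call as the old valid_vars did. A hedged, plain-Python sketch of that shape (names are illustrative):

class RunnerSketch(object):
    def __init__(self):
        # Created once at construction and reused across calls; stands in
        # for the RAW-type fake VarBase built by _create_fake_var().
        self.__fake_vars = [object()]

    def _valid_vars(self, vars):
        # Fall back to the cached sentinel when the real list is empty.
        return vars if vars else self.__fake_vars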

