Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 68 additions & 69 deletions python/paddle/fluid/dygraph/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from .base import program_desc_tracing_guard, param_guard
from paddle.fluid import framework
from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import _current_expected_place as _get_device

__all__ = ['Layer']

Expand Down Expand Up @@ -797,7 +800,7 @@ def _remove_if_exist(*dicts):
raise ValueError(
"super(YourLayer, self).__init__() should be called first")
if len(self._loaddict_holder) > 0:
assert value.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in stat_dict".format(
assert value.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in state_dict".format(
value.name)

value.set_value(self._loaddict_holder[value.name])
Expand Down Expand Up @@ -943,12 +946,13 @@ def state_dict(self,
destination = destination_temp
return destination

def set_dict(self,
stat_dict,
include_sublayers=True,
use_structured_name=True):
@framework.deprecate_stat_dict
def set_state_dict(self,
state_dict,
include_sublayers=True,
use_structured_name=True):
'''
Set parameters and persistable buffers from stat_dict. All the parameters and buffers will be reset by the tensor in the stat_dict
Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict

Parameters:
state_dict(dict) : Dict contains all the parameters and persistable buffers.
Expand All @@ -961,72 +965,67 @@ def set_dict(self,
Examples:
.. code-block:: python

import paddle.fluid as fluid
with fluid.dygraph.guard():
emb = fluid.dygraph.Embedding([10, 10])
import paddle

paddle.disable_static()

emb = paddle.nn.Embedding([10, 10])

state_dict = emb.state_dict()
fluid.save_dygraph( state_dict, "paddle_dy")

para_state_dict, _ = fluid.load_dygraph( "paddle_dy")

emb.set_dict( para_state_dict )
state_dict = emb.state_dict()
paddle.save(state_dict, "paddle_dy")

para_state_dict, _ = paddle.load("paddle_dy")

'''
self.load_dict(
stat_dict,
include_sublayers=include_sublayers,
use_structured_name=use_structured_name)
emb.set_state_dict(para_state_dict)

def load_dict(self,
stat_dict,
include_sublayers=True,
use_structured_name=True):
'''
Set parameters and persistable buffers from stat_dict. All the parameters and persistabl buffers will be reset by the tensor in the stat_dict

This api will be Deprecated. Please use set_dict

Parameters:
state_dict(dict) : Dict contains all the parameters and persistable buffers.
include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
Default: True
Returns:
None

Examples:
.. code-block:: python

import paddle.fluid as fluid
with fluid.dygraph.guard():
emb = fluid.dygraph.Embedding([10, 10])

state_dict = emb.state_dict()
fluid.save_dygraph( state_dict, "paddle_dy")

para_state_dict, _ = fluid.load_dygraph( "paddle_dy")

emb.load_dict( para_state_dict )

'''

inner_state_dict = self.state_dict()
def _check_match(key, param):
state = state_dict.get(key, None)
if state is None:
raise ValueError("{} is not found in the provided dict.".format(
key))
if list(state.shape) != list(param.shape):
raise ValueError(
"{} receives a shape {}, but the expected shape is {}.".
format(key, list(state.shape), list(param.shape)))
return param, state

matched_param_state = []
for key, param in self.state_dict().items():
key_name = key if use_structured_name else param.name
try:
match_res = _check_match(key_name, param)
matched_param_state.append(match_res)
except ValueError as err:
warnings.warn(("Skip loading for {}. ".format(key) + str(err)))

if in_dygraph_mode():
for param, state in matched_param_state:
param.set_value(state)
else:

for name, param_or_buffer in inner_state_dict.items():
key_name = name if use_structured_name else param_or_buffer.name
if key_name in stat_dict:
param_or_buffer.set_value(stat_dict[key_name])
else:
raise RuntimeError(
"Parameter or persistable buffer not found, Can't find [ {} ] in stat_dict"
"use_structured_name is set to [{}]".format(
key_name, use_structured_name))
unused_para_list = []
for k, v in stat_dict.items():
if k not in inner_state_dict:
unused_para_list.append(k)
if len(unused_para_list) > 0:
warnings.warn(
"Variables [ {} ] are not used, because not included in layers state_dict".
format(" ".join(unused_para_list)))
def _set_var(var, ndarray):
t = global_scope().find_var(var.name).get_tensor()
p = t._place()
if p.is_cpu_place():
place = core.CPUPlace()
elif p.is_cuda_pinned_place():
place = core.CUDAPinnedPlace()
else:
p = core.Place()
p.set_place(t._place())
place = core.CUDAPlace(p.gpu_device_id())
t.set(ndarray, place)

executor = Executor(_get_device())._default_executor
# restore parameter states
core._create_loaded_parameter(
[param for param, state in matched_param_state],
global_scope(), executor)
for param, state in matched_param_state:
_set_var(param, state)

# [aliases] Compatible with old method names
set_dict = set_state_dict
load_dict = set_state_dict
5 changes: 4 additions & 1 deletion python/paddle/fluid/dygraph/learning_rate_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _state_keys(self):
"""
self.keys = ['step_num']

def set_dict(self, state_dict):
def set_state_dict(self, state_dict):
"""
Loads the schedulers state.
"""
Expand All @@ -114,6 +114,9 @@ def set_dict(self, state_dict):
"There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
)

# [aliases] Compatible with old method names
set_dict = set_state_dict

def step(self):
raise NotImplementedError()

Expand Down
74 changes: 20 additions & 54 deletions python/paddle/fluid/dygraph/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,12 +587,13 @@ def state_dict(self,
include_sublayers=include_sublayers,
structured_name_prefix=structured_name_prefix)

def set_dict(self,
stat_dict,
include_sublayers=True,
use_structured_name=True):
@framework.deprecate_stat_dict
def set_state_dict(self,
state_dict,
include_sublayers=True,
use_structured_name=True):
'''
Set parameters of self._layers from stat_dict. All the parameters of self._layers will be reset by the tensor in the stat_dict
Set parameters of self._layers from state_dict. All the parameters of self._layers will be reset by the tensor in the state_dict

Parameters:
state_dict(dict) : Dict contains all the parameters
Expand All @@ -605,62 +606,27 @@ def set_dict(self,
Examples:
.. code-block:: python

import paddle.fluid as fluid
with fluid.dygraph.guard():
strategy=fluid.dygraph.prepare_context()
emb = fluid.dygraph.Embedding([10, 10])
emb = fluid.dygraph.DataParallel(emb, strategy)

state_dict = emb.state_dict()
fluid.save_dygraph( state_dict, "paddle_dy")

para_state_dict, _ = fluid.load_dygraph( "paddle_dy")

emb.set_dict( para_state_dict )
import paddle

'''

self._layers.set_dict(
stat_dict,
include_sublayers=include_sublayers,
use_structured_name=use_structured_name)

def load_dict(self,
stat_dict,
include_sublayers=True,
use_structured_name=True):
'''
Set parameters of self._layers from stat_dict. All the parameters of self._layers will be reset by the tensor in the stat_dict

This api will be Deprecated. Please use set_dict

Parameters:
state_dict(dict) : Dict contains all the parameters
include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True
use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter name as key.
Default: True
Returns:
None
paddle.disable_static()

Examples:
.. code-block:: python
emb = paddle.nn.Embedding([10, 10])
emb = fluid.dygraph.DataParallel(emb, strategy)

import paddle.fluid as fluid
with fluid.dygraph.guard():
strategy=fluid.dygraph.prepare_context()
emb = fluid.dygraph.Embedding([10, 10])
emb = fluid.dygraph.DataParallel(emb, strategy)
state_dict = emb.state_dict()
paddle.save(state_dict, "paddle_dy")

state_dict = emb.state_dict()
fluid.save_dygraph( state_dict, "paddle_dy")

para_state_dict, _ = fluid.load_dygraph( "paddle_dy")
para_state_dict, _ = paddle.load("paddle_dy")

emb.load_dict( para_state_dict )
emb.set_state_dict(para_state_dict)

'''

self._layers.load_dict(
stat_dict,
self._layers.set_state_dict(
state_dict,
include_sublayers=include_sublayers,
use_structured_name=use_structured_name)

# [aliases] Compatible with old method names
set_dict = set_state_dict
load_dict = set_state_dict
20 changes: 20 additions & 0 deletions python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from . import unique_name
import paddle.version as fluid_version
import warnings
import functools

__all__ = [
'Program',
Expand Down Expand Up @@ -238,6 +239,25 @@ def __impl__(*args, **kwargs):
return __impl__


# NOTE(chenweihang): There is an argument name typo (`stat_dict`; the correct name is
# `state_dict`) in the fluid APIs Layer.set_dict and Optimizer.load. To correct the
# argument name without introducing compatibility issues, we add this decorator.
# NOTE(chenweihang): We do not use `wrap_decorator` here because `wrap_decorator`
# moves kwargs to args, which does not work in this decoration case.
def deprecate_stat_dict(func):
    """Accept the deprecated ``stat_dict`` keyword as an alias for ``state_dict``.

    Some fluid APIs historically used the misspelled argument name
    ``stat_dict``. To correct the name without breaking existing callers,
    this decorator transparently renames the keyword ``stat_dict`` to
    ``state_dict`` and emits a ``DeprecationWarning``.

    Args:
        func (callable): Function whose corrected signature accepts a
            ``state_dict`` keyword argument.

    Returns:
        callable: A wrapper that forwards all arguments to ``func``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if 'stat_dict' in kwargs:
            # stacklevel=2 makes the warning report the caller's line,
            # not this wrapper's, so users can find the offending call.
            warnings.warn(
                "The argument `stat_dict` has been deprecated, please change "
                "it to `state_dict`.",
                DeprecationWarning,
                stacklevel=2)
            # Rename the keyword in a single step instead of assign-then-pop.
            kwargs['state_dict'] = kwargs.pop('stat_dict')
        return func(*args, **kwargs)

    return wrapper


dygraph_not_support = wrap_decorator(_dygraph_not_support_)
dygraph_only = wrap_decorator(_dygraph_only_)
fake_interface_only = wrap_decorator(_fake_interface_only_)
Expand Down
36 changes: 21 additions & 15 deletions python/paddle/fluid/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def state_dict(self):
return state_dict

@framework.dygraph_only
def set_dict(self, state_dict):
def set_state_dict(self, state_dict):
'''
Load optimizer state dict. For Adam optimizer, contains beta1, beta2, momentum etc. If LearningRateDecay have been used, global_step will be changed.

Expand All @@ -182,20 +182,22 @@ def set_dict(self, state_dict):
Examples:
.. code-block:: python

with fluid.dygraph.guard():
emb = fluid.dygraph.Embedding([10, 10])
import paddle

paddle.disable_static()

emb = paddle.nn.Embedding([10, 10])

state_dict = emb.state_dict()
fluid.save_dygraph(state_dict, "paddle_dy")
state_dict = emb.state_dict()
paddle.save(state_dict, "paddle_dy")

adam = fluid.optimizer.Adam(learning_rate=fluid.layers.noam_decay( 100, 10000),
adam = paddle.optimizer.Adam(learning_rate=fluid.layers.noam_decay( 100, 10000),
parameter_list=emb.parameters())
state_dict = adam.state_dict()
fluid.save_dygraph(state_dict, "paddle_dy")
state_dict = adam.state_dict()

para_state_dict, opti_state_dict = fluid.load_dygraph( "paddle_dy")
para_state_dict, opti_state_dict = paddle.load("paddle_dy")

adam.set_dict(opti_state_dict)
adam.set_state_dict(opti_state_dict)

'''
from paddle.optimizer.lr_scheduler import _LRScheduler
Expand Down Expand Up @@ -257,6 +259,9 @@ def set_dict(self, state_dict):

tensor.set(load_para_np, framework._current_expected_place())

# [aliases] Compatible with old method names
set_dict = set_state_dict

def get_opti_var_name_list(self):
return self._opti_name_list

Expand Down Expand Up @@ -4595,15 +4600,16 @@ def _set_checkpoints(self, checkpoints):
), "_checkpoints should be a list of Variable or a list of String"
self._checkpoints = checkpoints

def load(self, stat_dict):
@framework.deprecate_stat_dict
def load(self, state_dict):
"""
:api_attr: Static Graph
:api_attr: Static Graph

load function is not supported by Recompute Optimizer for now.
:return: None

Args:
stat_dict: the dict load by load_persistable method
state_dict: the dict load by load_persistable method

Examples:
.. code-block:: python
Expand All @@ -4627,8 +4633,8 @@ def mlp(input_x, input_y, hid_dim=128, label_dim=2):
sgd = fluid.optimizer.RecomputeOptimizer(sgd)
sgd._set_checkpoints([fc_1, pred])
try:
stat_dict = {}
sgd.load(stat_dict)
state_dict = {}
sgd.load(state_dict)
except NotImplementedError as e:
print(cpt.get_exception_message(e))
"""
Expand Down
Loading