Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added paddle/http.log
Empty file.
1 change: 1 addition & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@
from .framework import no_grad #DEFINE_ALIAS
from .framework import save #DEFINE_ALIAS
from .framework import load #DEFINE_ALIAS
from .framework import SaveLoadConfig #DEFINE_ALIAS
from .framework import DataParallel #DEFINE_ALIAS

from .framework import NoamDecay #DEFINE_ALIAS
Expand Down
168 changes: 98 additions & 70 deletions python/paddle/fluid/dygraph/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,54 @@

import os
import collections
import functools
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer
import pickle
import six
from . import learning_rate_scheduler
import warnings
from .. import core
from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME, _load_persistable_vars
from .base import guard
from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers

__all__ = [
'save_dygraph',
'load_dygraph',
]


# NOTE(chenweihang): deprecate load_dygraph's argument keep_name_table,
# ensure compatibility when user still use keep_name_table argument
def deprecate_keep_name_table(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
def __warn_and_build_configs__(keep_name_table):
warnings.warn(
"The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`.",
DeprecationWarning)
config = SaveLoadConfig()
config.keep_name_table = keep_name_table
return config

# deal with arg `keep_name_table`
if len(args) > 1 and isinstance(args[1], bool):
args = list(args)
args[1] = __warn_and_build_configs__(args[1])
# deal with kwargs
elif 'keep_name_table' in kwargs:
kwargs['config'] = __warn_and_build_configs__(kwargs[
'keep_name_table'])
kwargs.pop('keep_name_table')
else:
# do nothing
pass

return func(*args, **kwargs)

return wrapper


@dygraph_only
def save_dygraph(state_dict, model_path):
'''
Expand Down Expand Up @@ -100,41 +134,56 @@ def save_dygraph(state_dict, model_path):

# TODO(qingqing01): remove dygraph_only to support loading static model.
# maybe need to unify the loading interface after 2.0 API is ready.
#@dygraph_only
def load_dygraph(model_path, keep_name_table=False):
# @dygraph_only
@deprecate_save_load_configs
@deprecate_keep_name_table
def load_dygraph(model_path, config=None):
'''
:api_attr: imperative

Load parameter state_dict from disk.
Load parameter state dict from disk.

.. note::
Due to some historical reasons, if you load ``state_dict`` from the saved
result of `paddle.io.save_inference_model`, the structured variable name
will cannot be restored. You need to set the argument `use_structured_name=False`
when using `Layer.set_state_dict` later.

Args:
model_path(str) : The file prefix store the state_dict. (The path should Not contain suffix '.pdparams')
keep_name_table(bool, optional) : Whether keep structed name to parameter name conversion table in output dict.
Default : False
model_path(str) : The file prefix store the state_dict.
(The path should Not contain suffix '.pdparams')
config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
object that specifies additional configuration options, these options
are for compatibility with ``jit.save/io.save_inference_model`` formats.
Default None.

Returns:
state_dict(dict) : the dict store the state_dict

Examples:
.. code-block:: python

import paddle.fluid as fluid
import paddle

with fluid.dygraph.guard():
emb = fluid.dygraph.Embedding([10, 10])
paddle.disable_static()

state_dict = emb.state_dict()
fluid.save_dygraph( state_dict, "paddle_dy")
emb = paddle.nn.Embedding([10, 10])

adam = fluid.optimizer.Adam( learning_rate = fluid.layers.noam_decay( 100, 10000),
parameter_list = emb.parameters() )
state_dict = adam.state_dict()
fluid.save_dygraph( state_dict, "paddle_dy")
state_dict = emb.state_dict()
paddle.save(state_dict, "paddle_dy")

para_state_dict, opti_state_dict = fluid.load_dygraph( "paddle_dy")
scheduler = paddle.optimizer.lr_scheduler.NoamLR(
d_model=0.01, warmup_steps=100, verbose=True)
adam = paddle.optimizer.Adam(
learning_rate=scheduler,
parameters=emb.parameters())
state_dict = adam.state_dict()
paddle.save(state_dict, "paddle_dy")

'''
para_state_dict, opti_state_dict = paddle.load("paddle_dy")

'''
# deal with argument `model_path`
model_prefix = model_path
if model_prefix.endswith(".pdparams"):
model_prefix = model_prefix[:-9]
Expand All @@ -145,66 +194,45 @@ def load_dygraph(model_path, keep_name_table=False):
opti_dict = None
params_file_path = model_prefix + ".pdparams"
opti_file_path = model_prefix + ".pdopt"

# deal with argument `configs`
configs = config
if configs is None:
configs = SaveLoadConfig()

if not os.path.exists(params_file_path) and not os.path.exists(
opti_file_path):
# Load state dict by `jit.save` save format
# TODO(chenweihang): [Why not support `io.save_infernece_model` save format here]
# Load state dict by `jit.save/io.save_inference_model` save format
# NOTE(chenweihang): [ Compatibility of save_inference_model save format ]
# The model saved by `save_inference_model` does not completely correspond to
# the information required by the `state_dict` under the dygraph.
# Although we reluctantly restore the `state_dict` in some scenarios,
# this may not be complete and there are some limitations, so this function
# will be considered later. The limitations include:
# 1. `save_inference_model` not save structured name, we need to remind
# the user to configure the `use_structured_name` argument when `set_dict`,
# but this argument is currently not public
# 2. if `save_inference_model` save all persistable variables in a single file,
# user need to give the variable name list to load `state_dict`
# `save_inference_model` not save structured name, we need to remind
# the user to configure the `use_structured_name` argument when `set_state_dict`
# NOTE(chenweihang): `jit.save` doesn't save optimizer state

# 1. check model path
if not os.path.isdir(model_prefix):
raise ValueError("Model saved directory '%s' is not exists." %
model_prefix)
# 2. load `__variables.info__`
var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME)
if not os.path.exists(var_info_path):
raise RuntimeError(
"No target can be loaded. Now only supports loading `state_dict` from "
"the result saved by `imperative.save` and `imperative.jit.save`."
)
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f)
# 3. load `__variables__`
# TODO(chenweihang): now only supports loading from default save format:
# - all persistable vars saved in one file named `__variables__`
# for other case, we may need to modify the arguments of this API
var_file_path = os.path.join(model_prefix, VARIABLE_FILENAME)
if not os.path.exists(var_file_path):
raise RuntimeError(
"The parameter file to be loaded was not found. "
"Now only supports loading from the default save format, "
"and does not support custom params_filename and "
"save parameters separately.")
# 4. load all persistable vars
load_var_list = []
for name in sorted(extra_var_info):
var = _varbase_creator(name=name, persistable=True)
load_var_list.append(var)
_dygraph_tracer().trace_op(
type='load_combine',
inputs={},
outputs={'Out': load_var_list},
attrs={'file_path': var_file_path})
# 5. construct state_dict
para_dict = dict()
for var in load_var_list:
structured_name = extra_var_info[var.name].get('structured_name',
None)
if structured_name is None:
raise RuntimeError(
"Cannot find saved variable (%s)'s structured name in saved model.",
var.name)
para_dict[structured_name] = var.numpy()
# NOTE: `jit.save` doesn't save optimizer state

# 2. load program desc & construct _ProgramHolder
programs = _construct_program_holders(model_path,
configs.model_filename)

# 3. load layer parameters & buffers
# NOTE: using fluid.dygraph.guard() here will cause import error in py2
with guard():
persistable_var_dict = _construct_params_and_buffers(
model_prefix,
programs,
configs.separate_params,
configs.params_filename,
append_suffix=False)

# 4. construct state_dict
para_dict = dict()
for var_name in persistable_var_dict:
para_dict[var_name] = persistable_var_dict[var_name].numpy()
else:
# Load state dict by `save_dygraph` save format
para_dict = {}
Expand All @@ -213,7 +241,7 @@ def load_dygraph(model_path, keep_name_table=False):
para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')

if not keep_name_table and "StructuredToParameterName@@" in para_dict:
if not configs.keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"]

if os.path.exists(opti_file_path):
Expand Down
18 changes: 16 additions & 2 deletions python/paddle/fluid/dygraph/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,15 @@ def _load_persistable_vars(model_path,
return load_var_dict


# NOTE(chenweihang): to adapt paddle.load to get state_dict
def _remove_varname_suffix(var_dict, program_holder):
no_suffix_var_dict = dict()
for var_name in var_dict:
no_suffix_name = program_holder._suffix_varname_dict[var_name]
no_suffix_var_dict[no_suffix_name] = var_dict[var_name]
return no_suffix_var_dict


def _construct_program_holders(model_path, model_filename=None):
# make sure the path has been checked
program_holder_dict = dict()
Expand Down Expand Up @@ -517,7 +526,8 @@ def _construct_program_holders(model_path, model_filename=None):
def _construct_params_and_buffers(model_path,
programs,
separate_params=False,
params_filename=None):
params_filename=None,
append_suffix=True):
var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
if os.path.exists(var_info_path):
var_dict = _load_persistable_vars(model_path, var_info_path,
Expand All @@ -526,6 +536,10 @@ def _construct_params_and_buffers(model_path,
else:
var_dict = _load_persistable_vars_by_program(
model_path, programs['forward'], params_filename)

if not append_suffix:
var_dict = _remove_varname_suffix(var_dict, programs['forward'])

return var_dict


Expand Down Expand Up @@ -685,7 +699,7 @@ def _construct(model_path, configs=None):
# 1. load program desc & construct _ProgramHolder
programs = _construct_program_holders(model_path, model_filename)

# 2. load layer parameters & parameter attributes
# 2. load layer parameters & buffers
persistable_vars = _construct_params_and_buffers(
model_path, programs, separate_params, params_filename)

Expand Down
Loading