-
Notifications
You must be signed in to change notification settings - Fork 6k
Support load state dict form inference model format save result
#26718
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d5be234
a4d053b
8291d4e
94115d1
5564882
3625922
a98f3aa
d90ba22
118c34a
5b0d7a7
857b668
72125a6
4334e73
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -293,6 +293,8 @@ def __init__(self): | |
| self._model_filename = None | ||
| self._params_filename = None | ||
| self._separate_params = False | ||
| # used for `paddle.load` | ||
| self._keep_name_table = False | ||
|
|
||
| # NOTE: Users rarely use following configs, so these configs are not open to users, | ||
| # reducing user learning costs, but we retain the configuration capabilities | ||
|
|
@@ -600,6 +602,54 @@ def separate_params(self, value): | |
| % type(value)) | ||
| self._separate_params = value | ||
|
|
||
| @property | ||
| def keep_name_table(self): | ||
| """ | ||
| Configures whether keep ``structured_name -> parameter_name`` dict in loaded state dict. | ||
| This dict is the debugging information saved when call `paddle.save`. | ||
| It is generally only used for debugging and does not affect the actual training or inference. | ||
| By default, it will not be retained in `paddle.load` result. Default: False. | ||
| .. note:: | ||
| Only used for ``paddle.load``. | ||
| Examples: | ||
| .. code-block:: python | ||
| import paddle | ||
| paddle.disable_static() | ||
| linear = paddle.nn.Linear(5, 1) | ||
| state_dict = linear.state_dict() | ||
| paddle.save(state_dict, "paddle_dy") | ||
| configs = paddle.SaveLoadConfig() | ||
| configs.keep_name_table = True | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. keep_name_table默认设置为True的话,会有什么问题吗?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 也没有什么风险,只是用户载入的state_dict里面会多一些额外信息,这些信息对用户一般没有帮助,保留这个选项主要是为了兼容,怕有同学用利用了旧版实现里面的这个信息
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 或许我们可以先将这个参数作为内部参数,不向用户公开 |
||
| para_state_dict, _ = paddle.load("paddle_dy", configs) | ||
| print(para_state_dict) | ||
| # the name_table is 'StructuredToParameterName@@' | ||
| # {'bias': array([0.], dtype=float32), | ||
| # 'StructuredToParameterName@@': | ||
| # {'bias': u'linear_0.b_0', 'weight': u'linear_0.w_0'}, | ||
| # 'weight': array([[ 0.04230034], | ||
| # [-0.1222527 ], | ||
| # [ 0.7392676 ], | ||
| # [-0.8136974 ], | ||
| # [ 0.01211023]], dtype=float32)} | ||
| """ | ||
| return self._keep_name_table | ||
|
|
||
| @keep_name_table.setter | ||
| def keep_name_table(self, value): | ||
| if not isinstance(value, bool): | ||
| raise TypeError( | ||
| "The SaveLoadConfig.keep_name_table should be bool value, but received input's type is %s." | ||
| % type(value)) | ||
| self._keep_name_table = value | ||
|
|
||
|
|
||
| @switch_to_static_graph | ||
| def save(layer, model_path, input_spec=None, configs=None): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
configs -> config
这里建议用单数,因为SaveLoadConfig是单数形式
config = paddle.SaveLoadConfig()
可以后续统一再改下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,在下一个PR中修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里能否将SaveLoadConfig改为SaveLoadConfigs,这样兼容性更好处理一些,因为还涉及到jit.save和jit.load接口中的configs参数