[AMP] add state_dict and load_state_dict and unittest for class GradScaler #34300
Thanks for your contribution!
TCChenlong
left a comment
LGTM
zhiqiu
left a comment
LGTM
python/paddle/amp/grad_scaler.py
Outdated
# required: gpu,xpu
import paddle
paddle.set_device('gpu')
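The two newly added lines (the comment and the code) do not appear in either the Chinese or the English documentation preview. Were they added after the preview was generated? Also, do the other code examples need similar code?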
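Thanks. The new comment and code were added after the preview was generated; I will update the preview. paddle.set_device('gpu') does not need to be added, because GradScaler issues a reminder about the device when it is initialized, so that line has been deleted.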
  Args:
-     new_init_loss_scaling(int): The new_init_loss_scaling used to update initial loss scaling factor.
+     new_init_loss_scaling(float): The new_init_loss_scaling used to update initial loss scaling factor.
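Please confirm that the type change is fully compatible with the existing behavior. For example, is there any parameter type checking that depends on it?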
Thanks. The data type was float all along; I mistakenly committed int when I first added the docstring, and this change corrects it to float.
lanxianghit
left a comment
LGTM
print('save_load:', out_use_state_dict[0], out_no_state_dict[0])
self.assertTrue(
    np.allclose(
        out_use_state_dict[0], out_no_state_dict[0], atol=1.e-2))
Shouldn't these be exactly equal?
Thanks, they are equal after setting the flag FLAGS_cudnn_deterministic=True.
paddle.save(scaler.state_dict(), 'ResNet_model.pdparams')
dict_load = paddle.load('ResNet_model.pdparams')
scaler.load_state_dict(dict_load)
Check whether the state values are equal.
Thanks, the state values are equal.
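One way to make this check explicit is to compare the dictionary captured before saving with the one read back after loading. The following is only a minimal sketch under stated assumptions: `ToyScaler` and its field names are hypothetical stand-ins (not Paddle's GradScaler or its actual state keys), and `pickle` is used in place of paddle.save / paddle.load so the snippet runs without a GPU.

```python
import os
import pickle
import tempfile


class ToyScaler:
    """Hypothetical stand-in for a loss scaler; NOT Paddle's GradScaler."""

    def __init__(self, scale=1024.0, incr_ratio=2.0, decr_ratio=0.5):
        self.scale = scale
        self.incr_ratio = incr_ratio
        self.decr_ratio = decr_ratio
        self.good_steps = 0

    def state_dict(self):
        # Return a plain dict with everything needed to resume training.
        return {
            "scale": self.scale,
            "incr_ratio": self.incr_ratio,
            "decr_ratio": self.decr_ratio,
            "good_steps": self.good_steps,
        }

    def load_state_dict(self, state):
        # Restore every field; a missing key raises KeyError.
        self.scale = state["scale"]
        self.incr_ratio = state["incr_ratio"]
        self.decr_ratio = state["decr_ratio"]
        self.good_steps = state["good_steps"]


def round_trip(scaler, path):
    """Save the state to `path`, load it into a fresh scaler, return both dicts."""
    saved = scaler.state_dict()
    with open(path, "wb") as f:
        pickle.dump(saved, f)
    with open(path, "rb") as f:
        loaded = pickle.load(f)
    fresh = ToyScaler(scale=1.0)  # deliberately different initial state
    fresh.load_state_dict(loaded)
    return saved, fresh.state_dict()


scaler = ToyScaler(scale=4096.0)
scaler.good_steps = 7
with tempfile.TemporaryDirectory() as tmp:
    saved, restored = round_trip(scaler, os.path.join(tmp, "scaler_state.pkl"))

# The check requested in review: every state value survives the round trip.
assert saved == restored
```

In the unit test itself, the same comparison could presumably be applied to the dictionary returned by scaler.state_dict() before paddle.save and the one obtained after scaler.load_state_dict(dict_load).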
PR types
New features
PR changes
APIs
Describe
add state_dict and load_state_dict and unittest for class GradScaler
Chinese documentation link: http://10.136.157.23:8090/documentation/docs/zh/api/paddle/amp/GradScaler_cn.html?reviewVersion=jenkins-doc-review-2-191
English documentation link: http://10.136.157.23:8090/documentation/docs/zh/api/paddle/amp/GradScaler_cn.html?reviewVersion=jenkins-doc-review-2-191