Skip to content

Conversation

@zhangbo9674
Copy link
Contributor

@zhangbo9674 zhangbo9674 commented Jul 21, 2021

PR types

New features

PR changes

APIs

Describe

add state_dict and load_state_dict and unittest for class GradScaler
中文文档链接:http://10.136.157.23:8090/documentation/docs/zh/api/paddle/amp/GradScaler_cn.html?reviewVersion=jenkins-doc-review-2-191
英文文档链接:http://10.136.157.23:8090/documentation/docs/zh/api/paddle/amp/GradScaler_cn.html?reviewVersion=jenkins-doc-review-2-191

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

TCChenlong
TCChenlong previously approved these changes Jul 28, 2021
Copy link
Contributor

@TCChenlong TCChenlong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

zhiqiu
zhiqiu previously approved these changes Aug 2, 2021
Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

# required: gpu,xpu
import paddle
paddle.set_device('gpu')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

新增的两行注释和代码在中英文文档预览里都没有看到,是生成预览之后新添加的?另外,其他代码示例是否也需要添加类似的代码?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks,新增的注释和代码是在预览之后新增加的,我去更新一下预览代码,paddle.set_device('gpu')可以不添加,后面的GradScaler初始化的时候会针对device进行提醒,这里添加的这行代码已删除。

Args:
new_init_loss_scaling(int): The new_init_loss_scaling used to update initial loss scaling factor.
new_init_loss_scaling(float): The new_init_loss_scaling used to update initial loss scaling factor.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确认一下类型变更之后对原来的情况是否完全兼容?比如参数类型检查是否有相关设置

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks,原本数据类型是float,初次添加注释的时候错误提交了int,这里修改为了正确的数据类型float。

lanxianghit
lanxianghit previously approved these changes Aug 2, 2021
Copy link
Contributor

@lanxianghit lanxianghit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

print('save_load:', out_use_state_dict[0], out_no_state_dict[0])
self.assertTrue(
np.allclose(
out_use_state_dict[0], out_no_state_dict[0], atol=1.e-2))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be equal?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks,it is equal after set flag FLAGS_cudnn_deterministic=True.

Comment on lines +373 to +375
paddle.save(scaler.state_dict(), 'ResNet_model.pdparams')
dict_load = paddle.load('ResNet_model.pdparams')
scaler.load_state_dict(dict_load)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check if the state value are equal

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks,the state values are euqal.

@zhiqiu zhiqiu merged commit 99f8f5c into PaddlePaddle:develop Aug 11, 2021
@zhangbo9674 zhangbo9674 deleted the dev_gradscaler_state_dict branch September 14, 2022 02:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants