[XPU] AdamW support multi_precision #61694
Conversation
Your PR was submitted successfully. Thank you for contributing to this open-source project!
GuoxiaWang
left a comment
LGTM for typo fix
lj970926
left a comment
LGTM
test/xpu/test_adamw_op_xpu.py
Outdated
    def test_main(self):
        xpu_version = core.get_xpu_device_version(0)
        if xpu_version != core.XPUVersion.XPU3:
            return
Why is this restricted to KL3 only? As far as I can see, KL2 is supported as well.
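Because the newly written adamw_v2 only has a KL3 implementation in XDNN. The thinking at the time was to treat KL3 as the reference and not maintain KL2, so the test targets KL3 only.
In addition, the version implemented by composing operators may also support KL2. I have not tested that specifically; I will run it shortly and see how it behaves.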
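The version gate discussed above returns silently on other hardware, so the test is reported as passed even though nothing ran. A minimal sketch of one alternative is to call unittest's skipTest so the gate shows up in the test report. The XPUVersion enum and get_xpu_device_version function below are hypothetical stand-ins for core.XPUVersion and core.get_xpu_device_version, so that the example runs without Paddle or XPU hardware; the stand-in pretends to be a KL2 device.

```python
import enum
import unittest


class XPUVersion(enum.Enum):
    # Hypothetical stand-in for core.XPUVersion (assumption, not Paddle's enum).
    XPU2 = 2
    XPU3 = 3


def get_xpu_device_version(device_id: int) -> XPUVersion:
    # Hypothetical stand-in for core.get_xpu_device_version; pretends to be KL2.
    return XPUVersion.XPU2


class TestAdamWGated(unittest.TestCase):
    def test_main(self):
        if get_xpu_device_version(0) != XPUVersion.XPU3:
            # Report an explicit skip instead of returning silently.
            self.skipTest("adamw_v2 only has a KL3 implementation in XDNN")
        self.fail("not reached in this sketch, since the stand-in reports XPU2")


def run_gated_suite() -> unittest.TestResult:
    # Run the test case programmatically and return the collected result.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestAdamWGated)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

With this shape, a run on a non-KL3 device is counted under "skipped" together with the reason, rather than being indistinguishable from a real pass.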
test/xpu/test_adamw_op_xpu.py
Outdated
    param = paddle.randn(shape).astype(paddle.bfloat16)
    master_weight = param.astype(paddle.float32)
    grad = paddle.randn(shape).astype(paddle.bfloat16)
    main_grad = grad.astype(paddle.bfloat16)
What is main_grad used for here? Its data and dtype look identical to those of grad.
Thanks for pointing this out. I will check shortly; for now it looks like a change made while debugging that I forgot to revert.
lj970926
left a comment
LGTM
chenwhql
left a comment
LGTM for PADDLE_ENFORCE
PR types
New features
PR changes
OPs
Description
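[...] the adamw_v2 function.
AdamW did not support the case where multi_precision is true; it is now supported.
[...] master_param, converts the data type of grad, and writes the result into master_param_outs.
[...] paddle/phi/kernels/gpu/adamw_kernel.cu.
[...] the two classes TestAdamWOpMultiPrecisonWithMainGrad and TestAdamWOpMultiPrecison, synced to XPU.
[...] the type registration of reduce_mean_grad and reduce_mean_grad, adding the float16 type.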
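To illustrate why multi_precision matters, here is a minimal pure-Python sketch of one AdamW step that updates a float32-style master weight and then derives the low-precision parameter from it. It is not the XPU kernel or the adamw_v2 function: the names to_bf16, adamw_step and run_demo are hypothetical, bfloat16 is simulated by rounding float32 bit patterns, and the update is the textbook AdamW rule with decoupled weight decay, which is assumed here rather than taken from the PR.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even), as a float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even mantissa
    bits = (bits + rounding) & 0xFFFF0000   # keep only the upper 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def adamw_step(weight, grad, m, v, step,
               lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """One textbook AdamW step (assumed formula); returns the high-precision
    weight, its bfloat16 rounding, and the updated moments."""
    weight = weight * (1.0 - lr * weight_decay)   # decoupled weight decay
    m = beta1 * m + (1.0 - beta1) * grad          # first moment
    v = beta2 * v + (1.0 - beta2) * grad * grad   # second moment
    m_hat = m / (1.0 - beta1 ** step)             # bias correction
    v_hat = v / (1.0 - beta2 ** step)
    weight = weight - lr * m_hat / (math.sqrt(v_hat) + eps)
    return weight, to_bf16(weight), m, v


def run_demo(steps: int = 10):
    """Compare updating via a master weight against updating the bf16 value directly."""
    grad = to_bf16(0.5)                    # low-precision gradient, exact in bf16
    master, m, v = 1.0, 0.0, 0.0           # multi_precision-style path
    low_only, m2, v2 = 1.0, 0.0, 0.0       # path without a master weight
    low_from_master = 1.0
    for t in range(1, steps + 1):
        master, low_from_master, m, v = adamw_step(master, grad, m, v, t)
        # Without a master weight, the rounded value is the only state kept.
        _, low_only, m2, v2 = adamw_step(low_only, grad, m2, v2, t)
    return master, low_from_master, low_only
```

In this sketch each update is about 1e-3, which is smaller than half the bfloat16 spacing just below 1.0, so the low-precision-only parameter rounds back to 1.0 on every step, while the master weight accumulates the updates and the parameter derived from it eventually moves.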