Skip to content

Conversation

@houj04
Copy link
Contributor

@houj04 houj04 commented Feb 7, 2024

PR types

New features

PR changes

OPs

Description

  • 更新XHPC依赖到最新,因为需要使用新增的XDNN的adamw_v2函数。
  • 之前XPU下的AdamW不支持multi_precisiontrue的情况,现在支持了。
    • 针对KL2:通过算子拼接,在使用混合精度的情况下,在优化器中读取master_param,并转换grad的数据类型,并将结果写到master_param_outs里面。
    • 针对KL3:单独写了一个函数,最大程度抄paddle/phi/kernels/gpu/adamw_kernel.cu
  • 把GPU算子实现中,尾部的若干类型注册,同步到XPU上。
  • 把GPU的单测,包括基础计算,以及和混合精度相关的TestAdamWOpMultiPrecisonWithMainGradTestAdamWOpMultiPrecison这两个类,同步到XPU上。
  • 跑单测的时候发现有类型注册问题,修改了reduce_mean_gradreduce_mean_grad的类型注册,追加了float16类型。
  • 顺手修了python端的几个细节typo。
  • 顺手修了FA的某个单测计算阈值,在bfloat16下面稍微放松一点点。

@paddle-bot
Copy link

paddle-bot bot commented Feb 7, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Contributor

@GuoxiaWang GuoxiaWang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGMT fo typo fix

Galaxy1458
Galaxy1458 previously approved these changes Feb 22, 2024
Copy link
Contributor

@lj970926 lj970926 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

def test_main(self):
xpu_version = core.get_xpu_device_version(0)
if xpu_version != core.XPUVersion.XPU3:
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为什么只针对KL3呀,我看KL2好像也做了支持

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为新写的adamw_v2在XDNN里面只有KL3的实现,当时想着是以KL3为准,KL2不维护,所以这里写的只针对KL3。
另外用算子拼接实现的版本可能也支持KL2,没有特地测试过,稍后我跑一把看看情况。

param = paddle.randn(shape).astype(paddle.bfloat16)
master_weight = param.astype(paddle.float32)
grad = paddle.randn(shape).astype(paddle.bfloat16)
main_grad = grad.astype(paddle.bfloat16)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的main_grad是干什么用的?我看好像和grad数据和类型都完全一致

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感谢指出,稍后我查一下,目前看是为了debug而有些修改,忘记改回去了。

Copy link
Contributor

@lj970926 lj970926 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for PADDLE_ENFORCE

@QingshuChen QingshuChen merged commit 23fdbd1 into PaddlePaddle:develop Feb 23, 2024
@houj04 houj04 added the XPU label Sep 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants