Skip to content

Conversation

@ZzSean
Copy link
Contributor

@ZzSean ZzSean commented Jul 14, 2021

PR types

Performance optimization

PR changes

OPs

Describe

使用reduce实现broadcast add 反向,并且优化reduce最低维很小的时候的性能。

case pytorch 优化前 优化前相比pytorch 优化后 优化后相比pytorch 加速比
[50, 128, 1000], [128, 1000] 0.16995 0.16955 打平 (0.24%) 0.17211 打平 (2.36%) 0.99
[50, 128, 1000], [1, 128, 1000] 0.16811 0.16793 打平 (0.11%) 0.17109 打平 (1.78%) 0.98
[16, 2048, 7, 7], [16, 2048] 0.05707 0.07289 差于 (27.72%) 0.05359 打平 (0.43%) 1.36
[16, 2048, 16, 16], [16, 2048, 16, 16] 0.29779 0.25073 优于 (15.80%) 0.25122 优于 (15.29%) 1.00
[6, 1, 80, 46080], [1] 0.55631 0.54957 打平 (1.21%) 0.54828 打平 (1.42%) 1.00
[512, 896, 4, 12], [512, 896, 4, 1] 0.66823 2.82644 差于 (3.23x) 0.61699 优于 (7.66%) 4.58
[512, 896, 4, 12], [512, 896, 4, 1] fp16 0.45891 2.71549 差于 (4.92x) 0.39972 优于 (12.86%) 6.79
[32, 12, 128, 128], [32, 1, 1, 128] fp16 0.10937 0.45465 差于 (3.16x) 0.09436 优于 (13.45%) 4.82
[32, 1, 1, 128], [1, 12, 128, 1] fp16 0.11242 0.30392 差于 (1.70x) 0.08767 优于 (29.70%) 3.47

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@ZzSean ZzSean force-pushed the broadcast_add_bw branch from 1d4e7f3 to 46bb6ae Compare July 14, 2021 03:35
@ZzSean ZzSean marked this pull request as draft July 14, 2021 03:54
@ZzSean ZzSean force-pushed the broadcast_add_bw branch 2 times, most recently from 75c6a5f to ceed224 Compare July 14, 2021 06:09
@ZzSean ZzSean marked this pull request as ready for review July 14, 2021 06:10
@ZzSean ZzSean force-pushed the broadcast_add_bw branch from 8b39e8a to 3e78de0 Compare July 15, 2021 05:56
@ZzSean ZzSean force-pushed the broadcast_add_bw branch from 3e78de0 to 63a1107 Compare July 15, 2021 08:42
@ZzSean ZzSean force-pushed the broadcast_add_bw branch from 4d0b243 to c1018bd Compare July 15, 2021 12:16
@ZzSean ZzSean force-pushed the broadcast_add_bw branch from c1018bd to 7c3f56c Compare July 27, 2021 02:55
@ZzSean ZzSean force-pushed the broadcast_add_bw branch from e1c1389 to eb32654 Compare July 27, 2021 06:27
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LastDIm 和Any都走同一套流程是不是可以直接删除LastDim选项,这样能减少一个判断

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个判断是必要的,因为对应的index计算不同,对性能倒也没太大影响

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉和下面那个函数功能重复了,是否可以复用GetReduceDim

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这两个函数功能不同,这个是根据broadcast 的输入输出dim去计算reduce dim的,GetReduceDim这个函数是已知reduce dim,计算真正reduce的dim

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数其他reduce功能的op用不到吧?先放elementwise_op_function.h里面?将reduce_op.h都include进来,引入了很多用不到的代码,会影响编译速度吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对应的CUDA实现也可以删掉,L136-L295

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉这些判断、TensorCopy逻辑,与ElementwiseAddGradKernel、elementwise_add_grad中的判断有很多重复之处,能不能为CUDADeviceContext特化实现ElementwiseAddGradKernel,简化下代码逻辑?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉这是我们inplace逻辑的bug。。。先这样处理吧。。。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数其他reduce功能的op用不到吧?先放elementwise_op_function.h里面?将reduce_op.h都include进来,引入了很多用不到的代码,会影响编译速度吗?

Xreki
Xreki previously approved these changes Aug 20, 2021
Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Xreki Xreki merged commit 56c5e21 into PaddlePaddle:develop Aug 22, 2021
@ZzSean ZzSean deleted the broadcast_add_bw branch September 3, 2021 02:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants