
Conversation

ZzSean (Contributor) commented Apr 19, 2021

PR types

Performance optimization

PR changes

OPs

Describe

Unify the implementation of activation operators.
This PR modifies 25 activation operators in total, covering both their forward and backward passes:

  • 9 trigonometric ops: sin, cos, tan, asin, acos, atan, sinh, cosh, tanh
  • 3 relu-family ops: relu, leaky_relu, elu
  • 3 sigmoid-family ops: sigmoid, silu, logsigmoid
  • 3 rounding ops: ceil, floor, round
  • 6 math ops: sqrt, rsqrt, square, log, exp, reciprocal
  • 1 shrink op: softshrink

The performance improvement is similar within each category, so one operator per category is shown as a representative in the table below.
Test case shape: [16, 128, 257, 257]

| OP Name | FP32 old | FP32 new | Improvement | FP16 old | FP16 new | Improvement |
| --- | --- | --- | --- | --- | --- | --- |
| elu fwd | 1.6077 ms | 1.3114 ms | 22.6% | 898.68 us | 670.20 us | 34.1% |
| elu bwd | 2.1628 ms | 1.9057 ms | 13.5% | 1.5737 ms | 963.07 us | 63.4% |
| sigmoid fwd | 1.4083 ms | 1.3123 ms | 7.3% | 1.0360 ms | 674.24 us | 53.7% |
| sigmoid bwd | 2.0002 ms | 1.9059 ms | 4.9% | 1.1890 ms | 961.18 us | 23.7% |
| ceil fwd | 1.5198 ms | 1.3116 ms | 15.9% | 904.18 us | 670.92 us | 34.8% |
| ceil bwd | 1.4069 ms | 603.45 us | 133.1% | 909.00 us | 302.23 us | 200% |
| sin fwd | 1.5071 ms | 1.3132 ms | 14.8% | 989.87 us | 673.57 us | 47.0% |
| sin bwd | 2.0647 ms | 1.9062 ms | 8.3% | 1.3319 ms | 970.76 us | 37.2% |
| sqrt fwd | 1.4051 ms | 1.3121 ms | 7.1% | 950.01 us | 672.92 us | 41.2% |
| sqrt bwd | 2.0164 ms | 1.9069 ms | 5.7% | 1.3418 ms | 966.90 us | 38.7% |
| softshrink fwd | 1.5230 ms | 1.3118 ms | 16.1% | 910.58 us | 669.97 us | 35.9% |
| softshrink bwd | 2.0644 ms | 1.9057 ms | 8.3% | 1.2642 ms | 963.07 us | 31.3% |
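
For context on what "unify" means here, the change moves these ops onto a single elementwise-functor pattern. The sketch below is illustrative only: the functor and kernel names and the launcher signature are assumptions for this comment, not the committed Paddle code.

```cpp
// Illustrative sketch (assumed names, not the actual Paddle code):
// each activation becomes a small device functor, and one shared
// elementwise kernel launches all 25 ops through the same path.
template <typename T>
struct ExpFunctor {
  __device__ __forceinline__ T operator()(const T* args) const {
    return static_cast<T>(exp(static_cast<float>(args[0])));  // exp(x)
  }
};

template <typename T, typename Functor>
__global__ void UnaryElementwiseKernel(const T* x, T* out, int n,
                                       Functor func) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = func(&x[i]);
}
```

With one shared launch path, per-op tuning (block size, vectorization) only has to be done once, which is where the uniform speedups in the table come from.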

@paddle-bot-old commented

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

ZzSean force-pushed the activation_op_impl branch several times between April 20 and April 26, 2021.
// Quoted grad-functor snippet; note that exp(-x) is written twice.
CT dout = static_cast<CT>(args[0]);
CT x = static_cast<CT>(args[1]);
CT temp1 = one + exp(-x);
CT temp2 = x * exp(-x);
Contributor

Written this way, will exp() be called twice?

Contributor Author

The implementation has been revised accordingly.
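
For illustration, a minimal sketch of the kind of revision described, hoisting the shared exp(-x) into a temporary so it is evaluated only once (names follow the quoted snippet; the surrounding functor is assumed):

```cpp
// Sketch: evaluate the shared sub-expression exp(-x) a single time.
CT dout = static_cast<CT>(args[0]);
CT x = static_cast<CT>(args[1]);
CT temp = exp(-x);      // computed once
CT temp1 = one + temp;  // was: one + exp(-x)
CT temp2 = x * temp;    // was: x * exp(-x)
```

The compiler may already eliminate the duplicate call as a common sub-expression, but hoisting it makes the cost explicit and independent of optimization level.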

// logsigmoid forward, written with the log-sum-exp trick for numerical
// stability: logsigmoid(x) = -temp - log(exp(-temp) + exp(-x - temp)),
// where temp = max(-x, 0).
__device__ __forceinline__ T operator()(const T* args) const {
  CT x = static_cast<CT>(args[0]);
  CT temp = x > zero ? zero : -x;
  return T(-temp - log(exp(-temp) + exp(-x - temp)));
}
Contributor

Since only -temp is ever used, couldn't the negation be dropped when computing temp?

ZzSean (Contributor Author) commented Apr 27, 2021

To stay consistent with the original implementation and with the formula, I'd rather leave it unchanged for now.

// logsigmoid backward: dx = dout * sigmoid(-x), stabilized with the
// same shift temp = max(-x, 0) as the forward pass.
CT dout = static_cast<CT>(args[0]);
CT x = static_cast<CT>(args[1]);
CT temp = x > zero ? zero : -x;
return T(dout * (exp(-x - temp) / (exp(-temp) + exp(-x - temp))));
Contributor

  • Same as above: since only -temp is ever used, couldn't the negation be dropped when computing temp?
  • exp(-x - temp) appears in both the numerator and the denominator; could it be hoisted out?

Contributor Author

Simplified as suggested.
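
A minimal sketch of the simplification being discussed, hoisting the repeated exp(-x - temp) into a local (same functor shape as above; the exact committed code may differ):

```cpp
// Sketch: compute the shared term once and reuse it.
CT dout = static_cast<CT>(args[0]);
CT x = static_cast<CT>(args[1]);
CT temp = x > zero ? zero : -x;
CT shared = exp(-x - temp);  // used in numerator and denominator
return T(dout * (shared / (exp(-temp) + shared)));
```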

// elu forward: x if x >= 0, else alpha * (exp(x) - 1).
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
  CT x = static_cast<CT>(args[0]);
  return x >= zero ? args[0] : T(static_cast<CT>(alpha) * (exp(x) - one));
}
Contributor
Contributor Author
No change for now.

// elu backward: dx = dout if x >= 0, else dout * alpha * exp(x).
// Inputs: args[0], dout; args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
  CT dout = static_cast<CT>(args[0]);
  CT x = static_cast<CT>(args[1]);
  return x >= zero ? args[0] : T(dout * static_cast<CT>(alpha) * exp(x));
}
Contributor

Contributor Author

No change for now.

  auto temp1 = x < static_cast<T>(threshold * -1.f);
  auto temp2 = x > static_cast<T>(threshold);
- out.device(d) = x * (temp1 + temp2).template cast<T>();
+ out.device(d) = x * (temp1 || temp2).template cast<T>();
Contributor

Wouldn't addition work just as well?

Contributor Author

The unit tests include a case where threshold is negative; there both masks are true at the same time, so with + the output is doubled. A standalone illustration follows below.
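
A scalar stand-in for the Eigen expressions above (plain C++, not Paddle code; the values are made up for the example):

```cpp
#include <cstdio>

int main() {
  float threshold = -0.5f;  // negative threshold, as in the unit test
  float x = 0.2f;
  bool temp1 = x < -threshold;  // 0.2 < 0.5  -> true
  bool temp2 = x > threshold;   // 0.2 > -0.5 -> true
  // With '+', both masks fire and the kept value is doubled:
  printf("plus: %f\n", x * static_cast<float>(temp1 + temp2));   // 0.4
  // With '||', the mask saturates at 1 and the value is kept once:
  printf("or:   %f\n", x * static_cast<float>(temp1 || temp2));  // 0.2
  return 0;
}
```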

                               ThresholdedReluFunctor,
                               ThresholdedReluGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(hard_swish, HardSwish, HardSwishFunctor,
                               HardSwishGradFunctor);
Contributor

The registration macro is getting a bit long; let's clean it up in a follow-up.

@Xreki Xreki (Contributor) left a comment

LGTM and great work~

@Xreki Xreki merged commit eca8dcc into PaddlePaddle:develop Apr 27, 2021