Unify the implementation of activation operation #32348
Conversation
Thanks for your contribution!
CT dout = static_cast<CT>(args[0]);
CT x = static_cast<CT>(args[1]);
CT temp1 = one + exp(-x);
CT temp2 = x * exp(-x);
Written this way, will exp() be called twice?
The implementation has been revised.
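A minimal host-side sketch of the suggested fix: evaluate exp(-x) once and reuse it, instead of calling exp() twice as in the diff above. The final combination of the two temporaries into a SiLU-style gradient is my assumption; the diff only shows the temporaries.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical sketch: cache exp(-x) so exp() is evaluated only once.
template <typename CT>
CT SiluGradSketch(CT dout, CT x) {
  CT e = std::exp(-x);   // single exp() call
  CT temp1 = CT(1) + e;  // was: one + exp(-x)
  CT temp2 = x * e;      // was: x * exp(-x)
  // Assumed combination: d/dx [x * sigmoid(x)] = (1 + temp2/temp1) / temp1
  return dout * (CT(1) + temp2 / temp1) / temp1;
}
```

On a GPU, avoiding the second exp() matters because the transcendental units are a shared, limited resource per SM.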
__device__ __forceinline__ T operator()(const T* args) const {
  CT x = static_cast<CT>(args[0]);
  CT temp = x > zero ? zero : -x;
  return T(-temp - log(exp(-temp) + exp(-x - temp)));
Since only -temp is ever used, could temp be computed without the negation in the first place?
To stay consistent with the original implementation and the formula, I'll leave this unchanged for now.
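For context, a host-side sketch of why the kernel shifts by temp: with temp = max(-x, 0), both exponents in the log are non-positive, so exp() cannot overflow for large |x|. This is my reading of the kernel above, mirrored in plain C++.

```cpp
#include <cassert>
#include <cmath>

// Numerically stable log-sigmoid, mirroring the device functor:
// temp = max(-x, 0) keeps exp(-temp) and exp(-x - temp) in (0, 1].
double LogSigmoidSketch(double x) {
  double temp = x > 0.0 ? 0.0 : -x;
  return -temp - std::log(std::exp(-temp) + std::exp(-x - temp));
}
```

A naive -log(1 + exp(-x)) would overflow to infinity for very negative x; the shifted form stays finite.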
CT dout = static_cast<CT>(args[0]);
CT x = static_cast<CT>(args[1]);
CT temp = x > zero ? zero : -x;
return T(dout * (exp(-x - temp) / (exp(-temp) + exp(-x - temp))));
- Same as above: since only -temp is used, could temp be computed without the negation?
- exp(-x - temp) appears in both the numerator and the denominator; could it be factored out?
Simplified.
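A sketch of the factoring the reviewer asked for, in host-side C++ (the exact simplified kernel isn't shown in this thread, so this is an assumption): compute exp(-x - temp) once and reuse it in numerator and denominator.

```cpp
#include <cassert>
#include <cmath>

// Gradient of log-sigmoid with the shared subexpression factored out.
double LogSigmoidGradSketch(double dout, double x) {
  double temp = x > 0.0 ? 0.0 : -x;      // max(-x, 0), as in the forward pass
  double shifted = std::exp(-x - temp);  // computed once, used twice
  return dout * (shifted / (std::exp(-temp) + shifted));
}
```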
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
  CT x = static_cast<CT>(args[0]);
  return x >= zero ? args[0] : T(static_cast<CT>(alpha) * (exp(x) - one));
Not changing this for now.
__device__ __forceinline__ T operator()(const T* args) const {
  CT dout = static_cast<CT>(args[0]);
  CT x = static_cast<CT>(args[1]);
  return x >= zero ? args[0] : T(dout * static_cast<CT>(alpha) * exp(x));
Not changing this for now.
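The two ELU functors above pair as forward and backward; a host-side sketch (assuming alpha > 0, matching the branch structure in the diffs):

```cpp
#include <cassert>
#include <cmath>

// ELU forward: identity for x >= 0, alpha * (exp(x) - 1) otherwise.
double EluSketch(double x, double alpha) {
  return x >= 0.0 ? x : alpha * (std::exp(x) - 1.0);
}

// ELU backward: gradient is 1 for x >= 0 (so return dout unchanged,
// mirroring `return args[0]` in the kernel), alpha * exp(x) otherwise.
double EluGradSketch(double dout, double x, double alpha) {
  return x >= 0.0 ? dout : dout * alpha * std::exp(x);
}
```

Returning args[0] directly in the x >= 0 branch skips a round-trip through the compute type CT, which is presumably why both kernels do it.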
auto temp1 = x < static_cast<T>(threshold * -1.f);
auto temp2 = x > static_cast<T>(threshold);
out.device(d) = x * (temp1 + temp2).template cast<T>();
out.device(d) = x * (temp1 || temp2).template cast<T>();
Wouldn't addition work here as well?
The unit tests include cases where threshold is negative; with `+`, both conditions can be true at once and the result would be doubled.
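The point is easy to reproduce with scalars: when threshold is negative, `x < -threshold` and `x > threshold` can both hold for the same x, so summing the masks multiplies x by 2 while the logical-or keeps the factor at 1. A scalar sketch (the real code operates on Eigen tensor expressions):

```cpp
#include <cassert>

// hard_shrink with `+`-combined masks: wrong when both masks fire.
double HardShrinkPlus(double x, double t) {
  bool m1 = x < -t;
  bool m2 = x > t;
  return x * static_cast<double>(static_cast<int>(m1) + static_cast<int>(m2));
}

// hard_shrink with `||`-combined masks: factor is always 0 or 1.
double HardShrinkOr(double x, double t) {
  bool m1 = x < -t;
  bool m2 = x > t;
  return x * static_cast<double>(m1 || m2);
}
```

With t = -1 and x = 0.5, both masks are true: the `+` version returns 1.0 (doubled) while the `||` version returns 0.5.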
ThresholdedReluFunctor,
ThresholdedReluGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(hard_swish, HardSwish, HardSwishFunctor,
                               HardSwishGradFunctor);
The registration macro is getting rather long; let's optimize it in a follow-up.
Xreki left a comment
LGTM and great work~
PR types
Performance optimization
PR changes
OPs
Describe
Unify the implementation of activation operation
This PR modifies 25 activation operators in total, including both their forward and backward passes. Operators of the same category show similar performance gains, so one operator from each category (such as elu) is chosen as an example, as shown in the table below:
Case configuration: [16, 128, 257, 257]