Merged
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/elementwise_functor.h
@@ -1378,7 +1378,7 @@ struct CopySignGradXYFunctor {
if (x == static_cast<InT>(0))
outs[0] = static_cast<OutT>(0);
else
- outs[0] = static_cast<OutT>(dout * (funcs::copysign_func(x, y)) / x);
Contributor

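Is there a basis for this change? For example, does torch do it this way? My impression is that although this fixes the underflow problem, it may introduce a new precision problem.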

Contributor Author

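One basis for this change is that the original author of this kernel wrote three GradFunctors for the backward pass of copysign. CopySignGradXFunctor and CopySignGradXYFunctor differ in the order of operations used for the derivative with respect to x, which suggests that the author made a mistake when writing the x gradient in CopySignGradXYFunctor and put the closing parenthesis in the wrong place.
[image]
Of course, we can also refer to torch's implementation.
[image]
It likewise divides first and then multiplies.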

+ outs[0] = static_cast<OutT>(dout * (funcs::copysign_func(x, y) / x));
// dy = 0
outs[1] = static_cast<OutT>(0);
return outs;
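
The underflow discussed in the review can be reproduced outside Paddle. The following is a minimal sketch, not the kernel itself: it assumes IEEE 754 `float`, uses `std::copysign` as a stand-in for `funcs::copysign_func`, and the names `grad_mul_then_div`, `grad_div_then_mul`, and `demo` are hypothetical. The `volatile` temporary only forces the intermediate product to be rounded to `float`, so the effect does not depend on the compiler keeping extra precision.

```cpp
#include <cmath>
#include <cstdio>

// Old ordering: (dout * copysign(x, y)) / x.
// The product can underflow to zero before the division happens.
float grad_mul_then_div(float dout, float x, float y) {
  if (x == 0.0f) return 0.0f;
  // volatile forces the intermediate to be stored as a float (illustration only).
  volatile float prod = dout * std::copysign(x, y);
  return prod / x;
}

// New ordering: dout * (copysign(x, y) / x).
// For nonzero finite x the quotient is exactly +1 or -1, so the final
// multiplication only sets the sign of dout.
float grad_div_then_mul(float dout, float x, float y) {
  if (x == 0.0f) return 0.0f;
  return dout * (std::copysign(x, y) / x);
}

// Hypothetical helper that prints both results for one small input.
void demo() {
  const float dout = 1e-30f, x = 1e-30f, y = -1.0f;
  std::printf("multiply then divide: %g\n", grad_mul_then_div(dout, x, y));
  std::printf("divide then multiply: %g\n", grad_div_then_mul(dout, x, y));
}
```

With `dout = 1e-30f`, `x = 1e-30f`, `y = -1.0f`, the intermediate product in the old ordering has magnitude about 1e-60, which is below the smallest subnormal `float`, so the result collapses to zero. The new ordering returns `-1e-30f`. Because `copysign(x, y) / x` is exactly ±1 for any nonzero finite `x` under IEEE 754 division, the reordering does not add rounding error in this sketch. Whether that also holds for every type the real functor is instantiated with is not verified here.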