
Commit 1383a2f

add tanh_triple_grad composite logic (#56072) (#58657)
* decompose tanh_triple_grad and add it into prim_white_list test=develop
* fix TanhTripleGradKernel bugs test=develop
* decompose tanh_triple_grad test=develop
1 parent fdd0689 commit 1383a2f
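
For reference, the composite rule added below follows from differentiating the outputs of tanh_double_grad one more time. A sketch of the derivation, using the composite's argument names (out = y, grad_out_forward = dout, grad_x_grad_forward = ddx):

\[
y = \tanh(x), \qquad \frac{\partial y}{\partial x} = 1 - y^{2}
\]

tanh_double_grad produces

\[
g_{1} = -2\,y \cdot dout \cdot ddx, \qquad g_{2} = (1 - y^{2})\,ddx
\]

Back-propagating the upstream gradients \(\bar{g}_{1}\) (grad_out_new_grad) and \(\bar{g}_{2}\) (grad_out_grad_grad) through these two expressions gives

\[
\begin{aligned}
\bar{y} &= -2\,dout \cdot ddx \cdot \bar{g}_{1} - 2\,y \cdot ddx \cdot \bar{g}_{2} && (\texttt{out\_grad}) \\
\overline{dout} &= -2\,y \cdot ddx \cdot \bar{g}_{1} && (\texttt{grad\_out\_forward\_grad}) \\
\overline{ddx} &= -2\,y \cdot dout \cdot \bar{g}_{1} + (1 - y^{2})\,\bar{g}_{2} && (\texttt{grad\_x\_grad\_forward\_grad})
\end{aligned}
\]

Each branch of the composite below is one of these formulas with an absent optional gradient treated as zero.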

File tree: 4 files changed, +74, -2 lines

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
paddle/phi/api/yaml/backward.yaml
paddle/phi/kernels/impl/activation_grad_impl.h

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py (1 addition, 0 deletions)

@@ -69,6 +69,7 @@
    "subtract_double_grad",
    "add_triple_grad",
    "silu_double_grad",
+   "tanh_triple_grad",
]

# dict of special api that forward api's output will affect bacward api's output

paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h (70 additions, 0 deletions)

@@ -53,6 +53,76 @@ void tanh_double_grad(const Tensor& out,
  }
}

+template <typename T>
+void tanh_triple_grad(const Tensor& out,
+                      const Tensor& grad_out_forward,
+                      const Tensor& grad_x_grad_forward,
+                      const paddle::optional<Tensor>& grad_out_new_grad,
+                      const paddle::optional<Tensor>& grad_out_grad_grad,
+                      Tensor* out_grad,
+                      Tensor* grad_out_forward_grad,
+                      Tensor* grad_x_grad_forward_grad) {
+  if (out_grad) {
+    if (grad_out_grad_grad) {
+      if (grad_out_new_grad) {
+        auto out_grad_tmp =
+            (-2 * out * grad_x_grad_forward * grad_out_grad_grad.get()) -
+            (2 * grad_out_forward * grad_x_grad_forward *
+             grad_out_new_grad.get());
+        set_output<T>(out_grad_tmp, out_grad);
+      } else {
+        auto out_grad_tmp =
+            -2 * out * grad_x_grad_forward * grad_out_grad_grad.get();
+        set_output<T>(out_grad_tmp, out_grad);
+      }
+    } else {
+      if (grad_out_new_grad) {
+        auto out_grad_tmp = -(2 * grad_out_forward * grad_x_grad_forward *
+                              grad_out_new_grad.get());
+        set_output<T>(out_grad_tmp, out_grad);
+      } else {
+        auto out_grad_tmp = 0 * out;
+        set_output<T>(out_grad_tmp, out_grad);
+      }
+    }
+  }
+
+  if (grad_out_forward_grad) {
+    if (grad_out_new_grad) {
+      auto grad_out_forward_grad_tmp =
+          -2 * out * grad_x_grad_forward * grad_out_new_grad.get();
+      set_output<T>(grad_out_forward_grad_tmp, grad_out_forward_grad);
+    } else {
+      auto grad_out_forward_grad_tmp = 0 * out;
+      set_output<T>(grad_out_forward_grad_tmp, grad_out_forward_grad);
+    }
+  }
+
+  if (grad_x_grad_forward_grad) {
+    if (grad_out_grad_grad) {
+      if (grad_out_new_grad) {
+        auto grad_x_grad_forward_grad_tmp =
+            (1 - (out * out)) * grad_out_grad_grad.get() -
+            2 * out * grad_out_forward * grad_out_new_grad.get();
+        set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
+      } else {
+        auto grad_x_grad_forward_grad_tmp =
+            (1 - (out * out)) * grad_out_grad_grad.get();
+        set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
+      }
+    } else {
+      if (grad_out_new_grad) {
+        auto grad_x_grad_forward_grad_tmp =
+            -(2 * out * grad_out_forward * grad_out_new_grad.get());
+        set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
+      } else {
+        auto grad_x_grad_forward_grad_tmp = 0 * grad_x_grad_forward;
+        set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
+      }
+    }
+  }
+}
+
template <typename T>
void matmul_double_grad(const Tensor& x,
                        const Tensor& y,
paddle/phi/api/yaml/backward.yaml (1 addition, 0 deletions)

@@ -2144,6 +2144,7 @@
    param : [out, out, grad_x_grad_forward]
  kernel :
    func : tanh_triple_grad
+ composite : tanh_triple_grad(out, grad_out_forward, grad_x_grad_forward, grad_out_new_grad, grad_out_grad_grad, out_grad, grad_out_forward_grad, grad_x_grad_forward_grad)
  inplace : (grad_x_grad_forward -> grad_out_forward_grad)
  optional : grad_out_new_grad, grad_out_grad_grad
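With the composite entry registered, the decomposition is picked up when prim mode is enabled; otherwise the func : tanh_triple_grad kernel still runs. A minimal dygraph sketch of triggering the third-order path; note that the location and name of the prim switch (core.set_prim_eager_enabled) is an assumption drawn from Paddle's test conventions and varies across versions:

import paddle
from paddle.framework import core  # assumption: prim switch location varies by version

core.set_prim_eager_enabled(True)  # whitelisted *_grad ops decompose into primitives

x = paddle.to_tensor([0.3, -0.8], stop_gradient=False)
y = paddle.tanh(x)
(g1,) = paddle.grad(y, x, create_graph=True)   # 1st order: 1 - y**2
(g2,) = paddle.grad(g1, x, create_graph=True)  # 2nd order: -2 * y * (1 - y**2)
(g3,) = paddle.grad(g2, x)                     # 3rd order: the backward of tanh_double_grad,
                                               # i.e. the tanh_triple_grad rule above
print(g3)  # analytic value: 2 * (1 - y**2) * (3 * y**2 - 1)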

paddle/phi/kernels/impl/activation_grad_impl.h (2 additions, 2 deletions)

Bug fix: in the d_out_new and d_ddx branches, TanhTripleGradKernel resized d_dout instead of the tensor it was about to allocate, so those outputs could be allocated with stale shapes, and d_dout could be dereferenced even when it was not requested.

@@ -189,11 +189,11 @@ void TanhTripleGradKernel(const Context& dev_ctx,
    dev_ctx.template Alloc<T>(d_dout);
  }
  if (d_out_new) {
-   d_dout->Resize(out.dims());
+   d_out_new->Resize(out.dims());
    dev_ctx.template Alloc<T>(d_out_new);
  }
  if (d_ddx) {
-   d_dout->Resize(ddx.dims());
+   d_ddx->Resize(ddx.dims());
    dev_ctx.template Alloc<T>(d_ddx);
  }
  funcs::TanhTripleGradFunctor<T> functor;
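
A minimal dygraph check (kernel path, prim mode off) that the fixed outputs come back with the expected shapes:

import paddle

x = paddle.rand([4, 5])
x.stop_gradient = False
y = paddle.tanh(x)
(g1,) = paddle.grad(y, x, create_graph=True)
(g2,) = paddle.grad(g1, x, create_graph=True)
(g3,) = paddle.grad(g2, x)  # runs TanhTripleGradKernel when prim mode is off
assert g3.shape == x.shape  # each triple-grad output must mirror its source tensor's shape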
