@@ -53,6 +53,76 @@ void tanh_double_grad(const Tensor& out,
5353 }
5454}
5555
56+ template <typename T>
57+ void tanh_triple_grad (const Tensor& out,
58+ const Tensor& grad_out_forward,
59+ const Tensor& grad_x_grad_forward,
60+ const paddle::optional<Tensor>& grad_out_new_grad,
61+ const paddle::optional<Tensor>& grad_out_grad_grad,
62+ Tensor* out_grad,
63+ Tensor* grad_out_forward_grad,
64+ Tensor* grad_x_grad_forward_grad) {
65+ if (out_grad) {
66+ if (grad_out_grad_grad) {
67+ if (grad_out_new_grad) {
68+ auto out_grad_tmp =
69+ (-2 * out * grad_x_grad_forward * grad_out_grad_grad.get ()) -
70+ (2 * grad_out_forward * grad_x_grad_forward *
71+ grad_out_new_grad.get ());
72+ set_output<T>(out_grad_tmp, out_grad);
73+ } else {
74+ auto out_grad_tmp =
75+ -2 * out * grad_x_grad_forward * grad_out_grad_grad.get ();
76+ set_output<T>(out_grad_tmp, out_grad);
77+ }
78+ } else {
79+ if (grad_out_new_grad) {
80+ auto out_grad_tmp = -(2 * grad_out_forward * grad_x_grad_forward *
81+ grad_out_new_grad.get ());
82+ set_output<T>(out_grad_tmp, out_grad);
83+ } else {
84+ auto out_grad_tmp = 0 * out;
85+ set_output<T>(out_grad_tmp, out_grad);
86+ }
87+ }
88+ }
89+
90+ if (grad_out_forward_grad) {
91+ if (grad_out_new_grad) {
92+ auto grad_out_forward_grad_tmp =
93+ -2 * out * grad_x_grad_forward * grad_out_new_grad.get ();
94+ set_output<T>(grad_out_forward_grad_tmp, grad_out_forward_grad);
95+ } else {
96+ auto grad_out_forward_grad_tmp = 0 * out;
97+ set_output<T>(grad_out_forward_grad_tmp, grad_out_forward_grad);
98+ }
99+ }
100+
101+ if (grad_x_grad_forward_grad) {
102+ if (grad_out_grad_grad) {
103+ if (grad_out_new_grad) {
104+ auto grad_x_grad_forward_grad_tmp =
105+ (1 - (out * out)) * grad_out_grad_grad.get () -
106+ 2 * out * grad_out_forward * grad_out_new_grad.get ();
107+ set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
108+ } else {
109+ auto grad_x_grad_forward_grad_tmp =
110+ (1 - (out * out)) * grad_out_grad_grad.get ();
111+ set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
112+ }
113+ } else {
114+ if (grad_out_new_grad) {
115+ auto grad_x_grad_forward_grad_tmp =
116+ -(2 * out * grad_out_forward * grad_out_new_grad.get ());
117+ set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
118+ } else {
119+ auto grad_x_grad_forward_grad_tmp = 0 * grad_x_grad_forward;
120+ set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
121+ }
122+ }
123+ }
124+ }
125+
56126template <typename T>
57127void matmul_double_grad (const Tensor& x,
58128 const Tensor& y,
0 commit comments