@@ -14,32 +14,21 @@
 
 #include "paddle/phi/kernels/p_norm_grad_kernel.h"
 
+#include <vector>
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/abs_kernel.h"
+#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/reduce_grad_functions.h"
+#include "paddle/phi/kernels/reduce_amax_grad_kernel.h"
+#include "paddle/phi/kernels/sign_kernel.h"
 
 namespace phi {
 
-template <typename T>
-struct AbsMaxAndMinGradFunctor {
-  template <typename Context,
-            typename X,
-            typename Y,
-            typename DX,
-            typename DY,
-            typename Dim>
-  void operator()(const Context& place,
-                  X* x,
-                  Y* y,
-                  DX* dx,
-                  DY* dy,
-                  const Dim& dim,
-                  int size) {
-    dx->device(place) = dy->broadcast(dim) * (*x).sign() *
-                        ((*x).abs() == y->broadcast(dim)).template cast<T>();
-  }
-};
-
 template <typename T>
 struct PNormGradFunctor {
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -109,24 +98,53 @@ void PNormGradKernel(const Context& dev_ctx,
 
   auto xdim = in_x->dims();
   bool reduce_all = (in_norm->numel() == 1);
-  if (axis < 0) axis = xdim.size() + axis;
+  if (axis < 0) {
+    axis = xdim.size() + axis;
+  }
   const std::vector<int> dims = {axis};
 
   if (porder == 0) {
     phi::funcs::SetConstant<Context, T> set_zero;
     set_zero(dev_ctx, out_dx, static_cast<T>(0));
   } else if (porder == INFINITY || porder == -INFINITY) {
-    AbsMaxAndMinGradFunctor<T> functor;
-    funcs::LaunchReduceGradKernel<Context, T, AbsMaxAndMinGradFunctor<T>>(
-        dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
+    std::vector<int64_t> dims_for_amax;
+    if (reduce_all) {
+      dims_for_amax.resize(xdim.size());
+      for (int i = 0; i < xdim.size(); ++i) dims_for_amax[i] = i;
+    } else {
+      dims_for_amax.push_back(axis);
+    }
+
+    DenseTensor x_abs;
+    x_abs.Resize(in_x->dims());
+    dev_ctx.template Alloc<T>(&x_abs);
+    phi::AbsKernel<T, Context>(dev_ctx, *in_x, &x_abs);
 
+    DenseTensor amax_grad_out;
+    amax_grad_out.Resize(in_x->dims());
+    dev_ctx.template Alloc<T>(&amax_grad_out);
+    phi::ReduceAMaxGradKernel<T, Context>(dev_ctx,
+                                          x_abs,
+                                          *in_norm,
+                                          *in_norm_dy,
+                                          dims_for_amax,
+                                          keepdim,
+                                          reduce_all,
+                                          &amax_grad_out);
+    DenseTensor x_sign;
+    x_sign.Resize(in_x->dims());
+    dev_ctx.template Alloc<T>(&x_sign);
+    phi::SignKernel<T, Context>(dev_ctx, *in_x, &x_sign);
+
+    phi::MultiplyKernel<T, Context>(dev_ctx, amax_grad_out, x_sign, out_dx);
   } else {
     auto functor = PNormGradFunctor<T>(porder, epsilon);
     funcs::LaunchReduceGradKernel<Context, T, PNormGradFunctor<T>>(
        dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
   }
 }
 }  // namespace phi
+
 PD_REGISTER_KERNEL(p_norm_grad,
                    GPU,
                    ALL_LAYOUT,
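Note on the change above: the removed Eigen functor and the new composed-kernel path both implement the infinity-norm gradient dx = dy * sign(x) * (|x| == max|x|), i.e. only entries attaining the maximum absolute value receive gradient. The new path assembles it from existing phi kernels (AbsKernel -> ReduceAMaxGradKernel -> SignKernel -> MultiplyKernel) instead of a bespoke reduce-grad functor. The following is a minimal standalone C++ sketch of that formula, with no Paddle dependencies. It follows the removed functor's tie handling, where every tied maximum gets the full upstream gradient; whether ReduceAMaxGradKernel distributes gradient among ties the same way is an assumption here, not something the diff states.

// Minimal sketch of dx = dy * sign(x) * 1{|x| == max_i |x_i|}, mirroring
// the inf-norm branch of the diff on a plain array. Tie handling (full
// gradient to every element attaining the max) matches the removed Eigen
// functor; ReduceAMaxGradKernel's tie semantics are assumed, not verified.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<double> x = {-3.0, 1.0, 3.0, -2.0};
  const double dy = 1.0;  // upstream gradient of y = max_i |x_i|

  // Forward pass: y is the infinity norm of x.
  double y = 0.0;
  for (double v : x) y = std::max(y, std::fabs(v));

  // Backward pass: dx_i = dy * sign(x_i) * 1{|x_i| == y}.
  std::vector<double> dx(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double sign = static_cast<double>((x[i] > 0.0) - (x[i] < 0.0));
    const double mask = (std::fabs(x[i]) == y) ? 1.0 : 0.0;
    dx[i] = dy * sign * mask;
  }

  for (double g : dx) std::printf("%g ", g);  // prints: -1 0 1 0
  std::printf("\n");
  return 0;
}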