
#include "paddle/phi/kernels/p_norm_grad_kernel.h"

+#include <vector>
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/abs_kernel.h"
+#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/reduce_grad_functions.h"
+#include "paddle/phi/kernels/reduce_amax_grad_kernel.h"
+#include "paddle/phi/kernels/sign_kernel.h"

namespace phi {

-template <typename T>
-struct AbsMaxAndMinGradFunctor {
-  template <typename Context,
-            typename X,
-            typename Y,
-            typename DX,
-            typename DY,
-            typename Dim>
-  void operator()(const Context& place,
-                  X* x,
-                  Y* y,
-                  DX* dx,
-                  DY* dy,
-                  const Dim& dim,
-                  int size) {
-    dx->device(place) = dy->broadcast(dim) * (*x).sign() *
-                        ((*x).abs() == y->broadcast(dim)).template cast<T>();
-  }
-};
-
template <typename T>
struct PNormGradFunctor {
  HOSTDEVICE explicit inline PNormGradFunctor(float porder, float eps) {
@@ -86,24 +75,53 @@ void PNormGradKernel(const Context& dev_ctx, |

  auto xdim = in_x->dims();
  bool reduce_all = (in_norm->numel() == 1);
-  if (axis < 0) axis = xdim.size() + axis;
+  if (axis < 0) {
+    axis = xdim.size() + axis;
+  }
  const std::vector<int> dims = {axis};

  if (porder == 0) {
    phi::funcs::SetConstant<Context, T> set_zero;
    set_zero(dev_ctx, out_dx, static_cast<T>(0));
  } else if (porder == INFINITY || porder == -INFINITY) {
-    AbsMaxAndMinGradFunctor<T> functor;
-    funcs::LaunchReduceGradKernel<Context, T, AbsMaxAndMinGradFunctor<T>>(
-        dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
+    std::vector<int64_t> dims_for_amax;
+    if (reduce_all) {
+      dims_for_amax.resize(xdim.size());
+      for (int i = 0; i < xdim.size(); ++i) dims_for_amax[i] = i;
+    } else {
+      dims_for_amax.push_back(axis);
+    }
+
+    DenseTensor x_abs;
+    x_abs.Resize(in_x->dims());
+    dev_ctx.template Alloc<T>(&x_abs);
+    phi::AbsKernel<T, Context>(dev_ctx, *in_x, &x_abs);

+    DenseTensor amax_grad_out;
+    amax_grad_out.Resize(in_x->dims());
+    dev_ctx.template Alloc<T>(&amax_grad_out);
+    phi::ReduceAMaxGradKernel<T, Context>(dev_ctx,
+                                          x_abs,
+                                          *in_norm,
+                                          *in_norm_dy,
+                                          dims_for_amax,
+                                          keepdim,
+                                          reduce_all,
+                                          &amax_grad_out);
+    DenseTensor x_sign;
+    x_sign.Resize(in_x->dims());
+    dev_ctx.template Alloc<T>(&x_sign);
+    phi::SignKernel<T, Context>(dev_ctx, *in_x, &x_sign);
+
+    phi::MultiplyKernel<T, Context>(dev_ctx, amax_grad_out, x_sign, out_dx);
  } else {
    auto functor = PNormGradFunctor<T>(porder, epsilon);
    funcs::LaunchReduceGradKernel<Context, T, PNormGradFunctor<T>>(
        dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
  }
}
}  // namespace phi
+
PD_REGISTER_KERNEL(p_norm_grad,
                   GPU,
                   ALL_LAYOUT,
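The new INFINITY/-INFINITY branch composes phi's Abs, ReduceAMaxGrad, Sign, and Multiply kernels to produce the gradient that the deleted Eigen functor computed in a single broadcast expression, namely dx_i = dy * sign(x_i) * 1[|x_i| == max_j |x_j|]. A minimal CPU-side sketch of that formula for the reduce-all case is given below; the helper name and the 1-D layout are illustrative assumptions, not part of the patch or of the phi API.

// Standalone sketch (assumed helper, not part of the patch) of the
// infinity-norm gradient for the reduce-all case:
//   dx_i = dy * sign(x_i) * 1[|x_i| == max_j |x_j|]
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> PNormInfGradSketch(const std::vector<float>& x,
                                      float norm,  // forward output: max_j |x_j|
                                      float dy) {  // gradient w.r.t. the norm
  std::vector<float> dx(x.size(), 0.0f);
  for (std::size_t i = 0; i < x.size(); ++i) {
    // Only elements whose magnitude attains the max receive gradient;
    // the sign factor restores the direction lost by taking |x|.
    if (std::abs(x[i]) == norm) {
      dx[i] = dy * ((x[i] > 0.0f) - (x[i] < 0.0f));
    }
  }
  return dx;
}

Like the removed functor, this sketch routes the full dy to every element tied for the maximum; whether ReduceAMaxGradKernel keeps that convention or splits the gradient among ties is determined by that kernel's implementation.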