
Commit 65d8ed7

fix bug:vector_norm test=develop
1 parent b5992db commit 65d8ed7

File tree

2 files changed: +45 −25 lines


paddle/phi/kernels/gpu/p_norm_grad_kernel.cu

Lines changed: 42 additions & 24 deletions

@@ -14,32 +14,21 @@
 
 #include "paddle/phi/kernels/p_norm_grad_kernel.h"
 
+#include <vector>
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/abs_kernel.h"
+#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/reduce_grad_functions.h"
+#include "paddle/phi/kernels/reduce_amax_grad_kernel.h"
+#include "paddle/phi/kernels/sign_kernel.h"
 
 namespace phi {
 
-template <typename T>
-struct AbsMaxAndMinGradFunctor {
-  template <typename Context,
-            typename X,
-            typename Y,
-            typename DX,
-            typename DY,
-            typename Dim>
-  void operator()(const Context& place,
-                  X* x,
-                  Y* y,
-                  DX* dx,
-                  DY* dy,
-                  const Dim& dim,
-                  int size) {
-    dx->device(place) = dy->broadcast(dim) * (*x).sign() *
-                        ((*x).abs() == y->broadcast(dim)).template cast<T>();
-  }
-};
-
 template <typename T>
 struct PNormGradFunctor {
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -109,24 +98,53 @@ void PNormGradKernel(const Context& dev_ctx,
 
   auto xdim = in_x->dims();
   bool reduce_all = (in_norm->numel() == 1);
-  if (axis < 0) axis = xdim.size() + axis;
+  if (axis < 0) {
+    axis = xdim.size() + axis;
+  }
   const std::vector<int> dims = {axis};
 
   if (porder == 0) {
     phi::funcs::SetConstant<Context, T> set_zero;
     set_zero(dev_ctx, out_dx, static_cast<T>(0));
   } else if (porder == INFINITY || porder == -INFINITY) {
-    AbsMaxAndMinGradFunctor<T> functor;
-    funcs::LaunchReduceGradKernel<Context, T, AbsMaxAndMinGradFunctor<T>>(
-        dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
+    std::vector<int64_t> dims_for_amax;
+    if (reduce_all) {
+      dims_for_amax.resize(xdim.size());
+      for (int i = 0; i < xdim.size(); ++i) dims_for_amax[i] = i;
+    } else {
+      dims_for_amax.push_back(axis);
+    }
+
+    DenseTensor x_abs;
+    x_abs.Resize(in_x->dims());
+    dev_ctx.template Alloc<T>(&x_abs);
+    phi::AbsKernel<T, Context>(dev_ctx, *in_x, &x_abs);
 
+    DenseTensor amax_grad_out;
+    amax_grad_out.Resize(in_x->dims());
+    dev_ctx.template Alloc<T>(&amax_grad_out);
+    phi::ReduceAMaxGradKernel<T, Context>(dev_ctx,
+                                          x_abs,
+                                          *in_norm,
+                                          *in_norm_dy,
+                                          dims_for_amax,
+                                          keepdim,
+                                          reduce_all,
+                                          &amax_grad_out);
+    DenseTensor x_sign;
+    x_sign.Resize(in_x->dims());
+    dev_ctx.template Alloc<T>(&x_sign);
+    phi::SignKernel<T, Context>(dev_ctx, *in_x, &x_sign);
+
+    phi::MultiplyKernel<T, Context>(dev_ctx, amax_grad_out, x_sign, out_dx);
   } else {
     auto functor = PNormGradFunctor<T>(porder, epsilon);
     funcs::LaunchReduceGradKernel<Context, T, PNormGradFunctor<T>>(
         dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
   }
 }
 }  // namespace phi
+
 PD_REGISTER_KERNEL(p_norm_grad,
                    GPU,
                    ALL_LAYOUT,

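Note on the change above: the deleted AbsMaxAndMinGradFunctor handed the full incoming gradient to every element whose absolute value ties the maximum, while the new abs -> amax-grad -> sign -> multiply composition routes the backward pass through ReduceAMaxGradKernel. Below is a minimal CPU-side sketch of what the new path computes for a 1-D tensor, assuming ReduceAMaxGradKernel splits the gradient evenly among tied maxima (the semantics paddle.amax advertises); the helper name InfNormGrad is hypothetical and not part of the diff.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// dx[i] = dy * sign(x[i]) where |x[i]| equals the max absolute value,
// shared evenly among ties, and 0 elsewhere -- the scalar analogue of
// AbsKernel -> ReduceAMaxGradKernel -> SignKernel -> MultiplyKernel.
std::vector<double> InfNormGrad(const std::vector<double>& x, double dy) {
  double amax = 0.0;
  for (double v : x) amax = std::max(amax, std::fabs(v));
  std::size_t ties = 0;
  for (double v : x) ties += (std::fabs(v) == amax) ? 1u : 0u;
  std::vector<double> dx(x.size(), 0.0);
  for (std::size_t i = 0; i < x.size(); ++i) {
    if (std::fabs(x[i]) == amax) {
      const double sign = (x[i] > 0.0) - (x[i] < 0.0);
      dx[i] = dy * sign / static_cast<double>(ties);  // even split among ties
    }
  }
  return dx;
}

For x = {3.0, -3.0, 1.0} and dy = 1.0 this yields {0.5, -0.5, 0.0}; the deleted functor would have produced {1.0, -1.0, 0.0}, double-counting the tied maxima.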
paddle/phi/kernels/gpu/reduce_kernel.cu

Lines changed: 3 additions & 1 deletion
@@ -262,7 +262,9 @@ PD_REGISTER_KERNEL(amax_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(amin_grad,
                    GPU,

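Note on the registration change: float16 and bfloat16 are presumably added here because PNormGradKernel now calls ReduceAMaxGradKernel directly for the infinity-norm case, so amax_grad must be registered for every GPU dtype that p_norm_grad supports.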