Commit d1f6fee

fix bug:vector_norm test=develop
1 parent befa2b5 commit d1f6fee

2 files changed: +45, -25 lines


paddle/phi/kernels/gpu/p_norm_grad_kernel.cu

Lines changed: 42 additions & 24 deletions
@@ -14,32 +14,21 @@
 
 #include "paddle/phi/kernels/p_norm_grad_kernel.h"
 
+#include <vector>
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/abs_kernel.h"
+#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/reduce_grad_functions.h"
+#include "paddle/phi/kernels/reduce_amax_grad_kernel.h"
+#include "paddle/phi/kernels/sign_kernel.h"
 
 namespace phi {
 
-template <typename T>
-struct AbsMaxAndMinGradFunctor {
-  template <typename Context,
-            typename X,
-            typename Y,
-            typename DX,
-            typename DY,
-            typename Dim>
-  void operator()(const Context& place,
-                  X* x,
-                  Y* y,
-                  DX* dx,
-                  DY* dy,
-                  const Dim& dim,
-                  int size) {
-    dx->device(place) = dy->broadcast(dim) * (*x).sign() *
-                        ((*x).abs() == y->broadcast(dim)).template cast<T>();
-  }
-};
-
 template <typename T>
 struct PNormGradFunctor {
   HOSTDEVICE explicit inline PNormGradFunctor(float porder, float eps) {

@@ -86,24 +75,53 @@ void PNormGradKernel(const Context& dev_ctx,
 
   auto xdim = in_x->dims();
   bool reduce_all = (in_norm->numel() == 1);
-  if (axis < 0) axis = xdim.size() + axis;
+  if (axis < 0) {
+    axis = xdim.size() + axis;
+  }
   const std::vector<int> dims = {axis};
 
   if (porder == 0) {
     phi::funcs::SetConstant<Context, T> set_zero;
     set_zero(dev_ctx, out_dx, static_cast<T>(0));
   } else if (porder == INFINITY || porder == -INFINITY) {
-    AbsMaxAndMinGradFunctor<T> functor;
-    funcs::LaunchReduceGradKernel<Context, T, AbsMaxAndMinGradFunctor<T>>(
-        dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
+    std::vector<int64_t> dims_for_amax;
+    if (reduce_all) {
+      dims_for_amax.resize(xdim.size());
+      for (int i = 0; i < xdim.size(); ++i) dims_for_amax[i] = i;
+    } else {
+      dims_for_amax.push_back(axis);
+    }
+
+    DenseTensor x_abs;
+    x_abs.Resize(in_x->dims());
+    dev_ctx.template Alloc<T>(&x_abs);
+    phi::AbsKernel<T, Context>(dev_ctx, *in_x, &x_abs);
 
+    DenseTensor amax_grad_out;
+    amax_grad_out.Resize(in_x->dims());
+    dev_ctx.template Alloc<T>(&amax_grad_out);
+    phi::ReduceAMaxGradKernel<T, Context>(dev_ctx,
+                                          x_abs,
+                                          *in_norm,
+                                          *in_norm_dy,
+                                          dims_for_amax,
+                                          keepdim,
+                                          reduce_all,
+                                          &amax_grad_out);
+    DenseTensor x_sign;
+    x_sign.Resize(in_x->dims());
+    dev_ctx.template Alloc<T>(&x_sign);
+    phi::SignKernel<T, Context>(dev_ctx, *in_x, &x_sign);
+
+    phi::MultiplyKernel<T, Context>(dev_ctx, amax_grad_out, x_sign, out_dx);
   } else {
     auto functor = PNormGradFunctor<T>(porder, epsilon);
    funcs::LaunchReduceGradKernel<Context, T, PNormGradFunctor<T>>(
         dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
   }
 }
 }  // namespace phi
+
 PD_REGISTER_KERNEL(p_norm_grad,
                    GPU,
                    ALL_LAYOUT,
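For reference, both versions of the INFINITY branch compute the same quantity: for y = max_i |x_i|, the gradient is dx_i = dy * sign(x_i) * [|x_i| == y]. The deleted AbsMaxAndMinGradFunctor wrote this as a single Eigen expression; the new code composes AbsKernel, ReduceAMaxGradKernel, SignKernel, and MultiplyKernel. Below is a minimal standalone sketch of that identity in plain C++ (not phi code; values are made up, and exact tie handling in the new path follows ReduceAMaxGradKernel, which is not reproduced here):

// Standalone illustration: gradient of the inf-norm y = max_i |x_i|,
// computed as dx_i = dy * sign(x_i) * (|x_i| == y), the formula the
// deleted functor implemented. Here every maximizer gets the full dy.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<double> x = {-3.0, 1.0, 3.0, -2.0};
  const double dy = 1.0;  // upstream gradient of the scalar norm

  // Forward: y = max_i |x_i|  (the p == INFINITY norm)
  double y = 0.0;
  for (double v : x) y = std::max(y, std::fabs(v));

  // Backward: amax-grad of |x| (mask of maximizers times dy), then * sign(x)
  std::vector<double> dx(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    const double sign = (x[i] > 0.0) - (x[i] < 0.0);
    dx[i] = dy * sign * (std::fabs(x[i]) == y ? 1.0 : 0.0);
  }

  for (double g : dx) std::printf("%g ", g);  // prints: -1 0 1 0
  std::printf("\n");
  return 0;
}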

paddle/phi/kernels/gpu/reduce_kernel.cu

Lines changed: 3 additions & 1 deletion
@@ -262,7 +262,9 @@ PD_REGISTER_KERNEL(amax_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(amin_grad,
                    GPU,
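The reduce_kernel.cu change registers amax_grad for float16 and bfloat16; a plausible reading (the commit message does not say) is that p_norm_grad supports those dtypes on GPU, so the amax_grad kernel it now delegates to must cover them as well. As a quick numeric sanity check on the delegated formula, here is a central-difference probe of d/dx_i max_j |x_j| in plain C++ (made-up values with a unique maximizer, so tie handling does not matter):

// Central-difference check of the inf-norm gradient; the result should
// match sign(x_i) * (|x_i| == y) from the sketch above. Not phi code.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static double inf_norm(const std::vector<double>& x) {
  double y = 0.0;
  for (double v : x) y = std::max(y, std::fabs(v));
  return y;
}

int main() {
  const std::vector<double> x = {-3.0, 1.0, 2.5, -2.0};
  const double h = 1e-6;
  for (size_t i = 0; i < x.size(); ++i) {
    std::vector<double> xp = x, xm = x;
    xp[i] += h;
    xm[i] -= h;
    const double g = (inf_norm(xp) - inf_norm(xm)) / (2.0 * h);
    std::printf("dx[%zu] ~= %g\n", i, g);  // approx: -1, 0, 0, 0
  }
  return 0;
}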
