Commit 7d48cb8

bugfix:p_norm test=develop
1 parent 65d8ed7 commit 7d48cb8

File tree

2 files changed: +28 additions, -24 deletions

paddle/phi/kernels/gpu/p_norm_grad_kernel.cu

Lines changed: 7 additions & 22 deletions

@@ -50,29 +50,14 @@ struct PNormGradFunctor {
                   DY* dy,
                   const Dim& dim,
                   int size) {
-    auto x_mt = x->template cast<MT>();
-    auto y_mt = y->template cast<MT>();
-    auto dy_mt = dy->template cast<MT>();
-
-    auto norm_pow = y_mt.pow(-this->porder);
-    auto mask_norm_nonzero = (y_mt != static_cast<MT>(0)).template cast<MT>();
-
-    // Set to 0 where porder < 0 and x == 0
-    MT zero = static_cast<MT>(0);
-    auto mask_x_zero = (x_mt == zero).template cast<MT>();
-
-    MT is_porder_negative =
-        this->porder < zero ? static_cast<MT>(1) : static_cast<MT>(0);
-    auto invalid_mask = (mask_x_zero * is_porder_negative);
-    auto safe_pow =
-        x_mt.abs().pow(this->porder) * (static_cast<MT>(1) - invalid_mask);
-
+    auto unstable_term = (*x).abs().pow(this->porder);
+    auto mask = (*x) == x->constant(static_cast<T>(0));
+    auto stable_term =
+        mask.select(x->constant(static_cast<T>(0)), unstable_term);
+    auto self_scaled = (*x).sign() * stable_term;
+    auto norm_term = (*y).pow(-this->porder);
     dx->device(place) =
-        (safe_pow * x_mt.sign() * dy_mt.broadcast(dim) *
-         norm_pow.broadcast(dim) *
-         mask_norm_nonzero.broadcast(dim)  // Mask out positions where norm == 0
-         )
-            .template cast<T>();
+        self_scaled * dy->broadcast(dim) * norm_term.broadcast(dim);
   }

   MT porder;
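
Note (editor's sketch, not part of the commit): the rewritten functor is a restatement of the analytic p-norm gradient. Writing q for the exponent stored in this->porder, which the calling kernel is assumed to pass as p - 1, the expression computed above corresponds to

\[
  \frac{\partial \lVert x \rVert_p}{\partial x_i}
    = \operatorname{sign}(x_i)\,\frac{|x_i|^{\,q}}{\lVert x \rVert_p^{\,q}},
  \qquad q = p - 1 .
\]

The mask.select(...) term pins |x_i|^q to zero where x_i = 0, guarding against 0 raised to a non-positive power (the case the deleted code handled with explicit zero/negative-porder masks), while norm_term broadcasts the 1 / ||x||_p^q factor across the reduced dimension.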

paddle/phi/kernels/gpu/p_norm_kernel.cu

Lines changed: 21 additions & 2 deletions

@@ -22,6 +22,9 @@
 #include "paddle/phi/kernels/funcs/reduce_function.h"
 #include "paddle/phi/kernels/gpu/reduce.h"

+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/phi/kernels/activation_kernel.h"
+
 namespace phi {
 template <typename T>
 struct NonzeroFunctor {

@@ -134,8 +137,24 @@ void PNormKernel(const Context& dev_ctx,
         dev_ctx, *in_x, out_norm, FabsFunctor<T>(), reduce_axis);
   } else if (porder == 2.0) {
     // fast 2-norm
-    phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, SquareFunctor<MT>>(
-        dev_ctx, *in_x, &out_temp, SquareFunctor<MT>(), reduce_axis);
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    phi::DenseTensor temp_sum_of_squares_hp;
+    temp_sum_of_squares_hp.Resize(out_norm->dims());
+    dev_ctx.template Alloc<MT>(&temp_sum_of_squares_hp);
+    phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, SquareFunctor<T>>(
+        dev_ctx,
+        *in_x,
+        &temp_sum_of_squares_hp,
+        SquareFunctor<T>(),
+        reduce_axis);
+
+    phi::DenseTensor temp_norm_hp;
+    temp_norm_hp.Resize(out_norm->dims());
+    dev_ctx.template Alloc<MT>(&temp_norm_hp);
+    phi::SqrtKernel<MT>(dev_ctx, temp_sum_of_squares_hp, &temp_norm_hp);
+    phi::CastKernel<MT>(dev_ctx, temp_norm_hp, out_norm->dtype(), out_norm);
+    return;
+
   } else if (porder == 3.0) {
     // fast 3-norm
     phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, FabsCubicFunctor<MT>>(
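
Note (editor's sketch, not part of the commit): the new fast 2-norm path accumulates the sum of squares into a tensor of the higher-precision type MT, takes the square root while still in MT, and only casts back to the output dtype at the end, so low-precision inputs (e.g. float16/bfloat16) do not overflow or lose accuracy inside the reduction. A minimal standalone C++ illustration of that pattern follows; L2NormHighPrecision and the float/double types below are illustrative, not Paddle APIs.

// Standalone sketch of the "reduce wide, cast late" pattern used above (not Paddle code).
#include <cmath>
#include <cstdio>
#include <vector>

// T is the storage type, MT a wider compute type. In the kernel, T may be
// float16/bfloat16 and MT is its MPTypeTrait type (float); here plain
// float/double are used so the sketch compiles anywhere.
template <typename T, typename MT>
T L2NormHighPrecision(const std::vector<T>& x) {
  MT sum_sq = MT(0);
  for (const T& v : x) {
    MT hp = static_cast<MT>(v);  // widen before squaring, so v*v cannot overflow T
    sum_sq += hp * hp;           // accumulate in MT to limit rounding error
  }
  return static_cast<T>(std::sqrt(sum_sq));  // cast back only at the very end
}

int main() {
  std::vector<float> x(1 << 20, 1e-4f);  // many small contributions
  std::printf("l2 norm: %.6f\n", L2NormHighPrecision<float, double>(x));
  return 0;
}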

0 commit comments