Merged
98 changes: 51 additions & 47 deletions paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
@@ -14,37 +14,26 @@

#include "paddle/phi/kernels/p_norm_grad_kernel.h"

#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/abs_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/reduce_grad_functions.h"
#include "paddle/phi/kernels/reduce_amax_grad_kernel.h"
#include "paddle/phi/kernels/sign_kernel.h"

namespace phi {

template <typename T>
struct AbsMaxAndMinGradFunctor {
template <typename Context,
typename X,
typename Y,
typename DX,
typename DY,
typename Dim>
void operator()(const Context& place,
X* x,
Y* y,
DX* dx,
DY* dy,
const Dim& dim,
int size) {
dx->device(place) = dy->broadcast(dim) * (*x).sign() *
((*x).abs() == y->broadcast(dim)).template cast<T>();
}
};

template <typename T>
struct PNormGradFunctor {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
HOSTDEVICE explicit inline PNormGradFunctor(float porder, float eps) {
this->porder = static_cast<MT>(porder - 1.);
this->porder = static_cast<MT>(porder - 1.0f);
this->eps = static_cast<MT>(eps);
}

@@ -61,29 +50,16 @@ struct PNormGradFunctor {
DY* dy,
const Dim& dim,
int size) {
auto x_mt = x->template cast<MT>();
auto y_mt = y->template cast<MT>();
auto dy_mt = dy->template cast<MT>();

auto norm_pow = y_mt.pow(-this->porder);
auto mask_norm_nonzero = (y_mt != static_cast<MT>(0)).template cast<MT>();

// Set to 0 where porder < 0 and x == 0
MT zero = static_cast<MT>(0);
auto mask_x_zero = (x_mt == zero).template cast<MT>();

MT is_porder_negative =
this->porder < zero ? static_cast<MT>(1) : static_cast<MT>(0);
auto invalid_mask = (mask_x_zero * is_porder_negative);
auto safe_pow =
x_mt.abs().pow(this->porder) * (static_cast<MT>(1) - invalid_mask);

auto unstable_term =
(*x).abs().template cast<MT>().pow(this->porder).template cast<T>();
auto mask = (*x) == x->constant(static_cast<T>(0));
auto stable_term =
mask.select(x->constant(static_cast<T>(0)), unstable_term);
auto self_scaled = (*x).sign() * stable_term;
auto norm_term =
(*y).template cast<MT>().pow(-this->porder).template cast<T>();
dx->device(place) =
(safe_pow * x_mt.sign() * dy_mt.broadcast(dim) *
norm_pow.broadcast(dim) *
mask_norm_nonzero.broadcast(dim) // Mask out positions where norm == 0
)
.template cast<T>();
self_scaled * dy->broadcast(dim) * norm_term.broadcast(dim);
}

MT porder;
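For reference, PNormGradFunctor implements the analytic gradient of the p-norm. With ||x||_p = (sum_i |x_i|^p)^(1/p) and upstream gradient dy, the chain rule gives

    \frac{\partial \lVert x \rVert_p}{\partial x_i} = \operatorname{sign}(x_i)\,|x_i|^{p-1}\,\lVert x \rVert_p^{\,1-p}

so dx = sign(x) * |x|^(p-1) * dy * ||x||_p^-(p-1), i.e. self_scaled * dy.broadcast(dim) * norm_term.broadcast(dim) with this->porder holding p - 1; the mask on x == 0 keeps |0|^(p-1) from producing Inf/NaN when p < 1.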
@@ -109,24 +85,52 @@ void PNormGradKernel(const Context& dev_ctx,

auto xdim = in_x->dims();
bool reduce_all = (in_norm->numel() == 1);
if (axis < 0) axis = xdim.size() + axis;
if (axis < 0) {
axis = xdim.size() + axis;
}
const std::vector<int> dims = {axis};

if (porder == 0) {
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, out_dx, static_cast<T>(0));
} else if (porder == INFINITY || porder == -INFINITY) {
AbsMaxAndMinGradFunctor<T> functor;
funcs::LaunchReduceGradKernel<Context, T, AbsMaxAndMinGradFunctor<T>>(
dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
std::vector<int64_t> dims_for_amax;
if (reduce_all) {
dims_for_amax.resize(xdim.size());
for (int i = 0; i < xdim.size(); ++i) dims_for_amax[i] = i;
} else {
dims_for_amax.push_back(axis);
}

DenseTensor x_abs;
x_abs.Resize(in_x->dims());
dev_ctx.template Alloc<T>(&x_abs);
phi::AbsKernel<T, Context>(dev_ctx, *in_x, &x_abs);

DenseTensor amax_grad_out;
amax_grad_out.Resize(in_x->dims());
dev_ctx.template Alloc<T>(&amax_grad_out);
phi::ReduceAMaxGradKernel<T, Context>(dev_ctx,
x_abs,
*in_norm,
*in_norm_dy,
dims_for_amax,
keepdim,
reduce_all,
&amax_grad_out);
DenseTensor x_sign;
x_sign.Resize(in_x->dims());
dev_ctx.template Alloc<T>(&x_sign);
phi::SignKernel<T, Context>(dev_ctx, *in_x, &x_sign);
phi::MultiplyKernel<T, Context>(dev_ctx, amax_grad_out, x_sign, out_dx);
} else {
auto functor = PNormGradFunctor<T>(porder, epsilon);
funcs::LaunchReduceGradKernel<Context, T, PNormGradFunctor<T>>(
dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
}
}
} // namespace phi

PD_REGISTER_KERNEL(p_norm_grad,
GPU,
ALL_LAYOUT,
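For the infinity norm, the (sub)gradient of y = max_i |x_i| is dx_i = dy * sign(x_i) * [|x_i| == y], which is what both the AbsMaxAndMinGradFunctor expression and the AbsKernel -> ReduceAMaxGradKernel -> SignKernel -> MultiplyKernel composition above compute. A minimal host-side sketch of that element-wise rule (illustrative standalone C++ only, not a Paddle API; tie handling in ReduceAMaxGradKernel may differ from this all-maxima version):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {-3.0f, 1.0f, 3.0f, 2.0f};
  float dy = 1.0f;  // upstream gradient of the scalar infinity norm

  // Forward: y = max_i |x_i|
  float y = 0.0f;
  for (float v : x) y = std::fmax(y, std::fabs(v));

  // Backward: dx_i = dy * sign(x_i) * [|x_i| == y]
  for (float v : x) {
    float sgn = static_cast<float>((v > 0.0f) - (v < 0.0f));
    float dx = dy * sgn * (std::fabs(v) == y ? 1.0f : 0.0f);
    std::printf("x = % .1f  ->  dx = % .1f\n", v, dx);
  }
  return 0;
}
```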
35 changes: 25 additions & 10 deletions paddle/phi/kernels/gpu/p_norm_kernel.cu
@@ -22,6 +22,8 @@
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpu/reduce.h"

#include "paddle/phi/kernels/activation_kernel.h"

namespace phi {
template <typename T>
struct NonzeroFunctor {
@@ -132,10 +134,26 @@ void PNormKernel(const Context& dev_ctx,
// fast 1-norm
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, FabsFunctor<T>>(
dev_ctx, *in_x, out_norm, FabsFunctor<T>(), reduce_axis);
return;
} else if (porder == 2.0) {
// fast 2-norm
phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, SquareFunctor<MT>>(
dev_ctx, *in_x, &out_temp, SquareFunctor<MT>(), reduce_axis);
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
phi::DenseTensor temp_sum_of_squares_hp;
temp_sum_of_squares_hp.Resize(out_norm->dims());
dev_ctx.template Alloc<MT>(&temp_sum_of_squares_hp);
phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, SquareFunctor<T>>(
dev_ctx,
*in_x,
&temp_sum_of_squares_hp,
SquareFunctor<T>(),
reduce_axis);

phi::DenseTensor temp_norm_hp;
temp_norm_hp.Resize(out_norm->dims());
dev_ctx.template Alloc<MT>(&temp_norm_hp);
phi::SqrtKernel<MT>(dev_ctx, temp_sum_of_squares_hp, &temp_norm_hp);
phi::CastKernel<MT>(dev_ctx, temp_norm_hp, out_norm->dtype(), out_norm);
return;
} else if (porder == 3.0) {
// fast 3-norm
phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, FabsCubicFunctor<MT>>(
@@ -149,14 +167,11 @@
UnsignedPowFunctor<MT>(porder),
reduce_axis);
}

if (porder != 1.0) {
std::vector<const DenseTensor*> ins = {&out_temp};
std::vector<DenseTensor*> outs = {out_norm};
MT p_order_ = static_cast<MT>(1.f / porder);
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, UnsignedPowFunctor<MT>(p_order_));
}
std::vector<const DenseTensor*> ins = {&out_temp};
std::vector<DenseTensor*> outs = {out_norm};
MT p_order_ = static_cast<MT>(1.f / porder);
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, UnsignedPowFunctor<MT>(p_order_));
#endif
}
}
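In the fast 2-norm path above, the sum of squares and the sqrt are kept in the higher-precision MT type and only the final result is cast to out_norm's dtype. A plausible motivation (an assumption, not stated in the diff) is that for float16/bfloat16 inputs the intermediate sum of squares can overflow even when the norm itself is representable; a minimal standalone C++ illustration, with float standing in for MT:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const float kFp16Max = 65504.0f;  // largest finite float16 value
  float x = 300.0f;                 // representable in float16

  // 300^2 = 90000 is fine in float, but already past the float16 range,
  // so a half-precision accumulator would overflow to inf before the sqrt.
  float sum_sq = x * x;
  float norm = std::sqrt(sum_sq);

  std::printf("sum of squares = %.0f (fits in fp16: %s)\n", sum_sq,
              sum_sq <= kFp16Max ? "yes" : "no");
  std::printf("norm           = %.0f (fits in fp16: %s)\n", norm,
              norm <= kFp16Max ? "yes" : "no");
  return 0;
}
```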
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/reduce_kernel.cu
@@ -262,7 +262,9 @@ PD_REGISTER_KERNEL(amax_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(amin_grad,
GPU,
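The added float16/bfloat16 registrations for amax_grad appear to be what allows the infinity-norm gradient path in p_norm_grad_kernel.cu, which calls ReduceAMaxGradKernel, to run for half-precision inputs; without them that code path would have no matching GPU kernel registered for those dtypes.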