Skip to content

Commit 38cc403

Browse files
bugfix:p_norm test=develop
1 parent 963ffcc commit 38cc403

File tree

1 file changed

+0
-44
lines changed

1 file changed

+0
-44
lines changed

paddle/phi/kernels/gpu/p_norm_grad_kernel.cu

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -66,50 +66,6 @@ struct PNormGradFunctor {
6666
MT eps;
6767
};
6868

69-
// template <typename Context,
70-
// typename X,
71-
// typename Y,
72-
// typename DX,
73-
// typename DY,
74-
// typename Dim>
75-
// void operator()(const Context& place,
76-
// X* x,
77-
// Y* y,
78-
// DX* dx,
79-
// DY* dy,
80-
// const Dim& dim,
81-
// int size) {
82-
// auto x_mt = x->template cast<MT>();
83-
// auto y_mt = y->template cast<MT>();
84-
// auto dy_mt = dy->template cast<MT>();
85-
86-
// auto norm_pow = y_mt.pow(-this->porder);
87-
// auto mask_norm_nonzero = (y_mt != static_cast<MT>(0)).template
88-
// cast<MT>();
89-
90-
// // Set to 0 where porder < 0 and x == 0
91-
// MT zero = static_cast<MT>(0);
92-
// auto mask_x_zero = (x_mt == zero).template cast<MT>();
93-
94-
// MT is_porder_negative =
95-
// this->porder < zero ? static_cast<MT>(1) : static_cast<MT>(0);
96-
// auto invalid_mask = (mask_x_zero * is_porder_negative);
97-
// auto safe_pow =
98-
// x_mt.abs().pow(this->porder) * (static_cast<MT>(1) - invalid_mask);
99-
100-
// dx->device(place) =
101-
// (safe_pow * x_mt.sign() * dy_mt.broadcast(dim) *
102-
// norm_pow.broadcast(dim) *
103-
// mask_norm_nonzero.broadcast(dim) // Mask out positions where norm
104-
// == 0
105-
// )
106-
// .template cast<T>();
107-
// }
108-
109-
// MT porder;
110-
// MT eps;
111-
// };
112-
11369
template <typename T, typename Context>
11470
void PNormGradKernel(const Context& dev_ctx,
11571
const DenseTensor& x,

0 commit comments

Comments
 (0)