File tree Expand file tree Collapse file tree 1 file changed +0
-44
lines changed Expand file tree Collapse file tree 1 file changed +0
-44
lines changed Original file line number Diff line number Diff line change @@ -66,50 +66,6 @@ struct PNormGradFunctor {
6666 MT eps;
6767};
6868
69- // template <typename Context,
70- // typename X,
71- // typename Y,
72- // typename DX,
73- // typename DY,
74- // typename Dim>
75- // void operator()(const Context& place,
76- // X* x,
77- // Y* y,
78- // DX* dx,
79- // DY* dy,
80- // const Dim& dim,
81- // int size) {
82- // auto x_mt = x->template cast<MT>();
83- // auto y_mt = y->template cast<MT>();
84- // auto dy_mt = dy->template cast<MT>();
85-
86- // auto norm_pow = y_mt.pow(-this->porder);
87- // auto mask_norm_nonzero = (y_mt != static_cast<MT>(0)).template
88- // cast<MT>();
89-
90- // // Set to 0 where porder < 0 and x == 0
91- // MT zero = static_cast<MT>(0);
92- // auto mask_x_zero = (x_mt == zero).template cast<MT>();
93-
94- // MT is_porder_negative =
95- // this->porder < zero ? static_cast<MT>(1) : static_cast<MT>(0);
96- // auto invalid_mask = (mask_x_zero * is_porder_negative);
97- // auto safe_pow =
98- // x_mt.abs().pow(this->porder) * (static_cast<MT>(1) - invalid_mask);
99-
100- // dx->device(place) =
101- // (safe_pow * x_mt.sign() * dy_mt.broadcast(dim) *
102- // norm_pow.broadcast(dim) *
103- // mask_norm_nonzero.broadcast(dim) // Mask out positions where norm
104- // == 0
105- // )
106- // .template cast<T>();
107- // }
108-
109- // MT porder;
110- // MT eps;
111- // };
112-
11369template <typename T, typename Context>
11470void PNormGradKernel (const Context& dev_ctx,
11571 const DenseTensor& x,
You can’t perform that action at this time.
0 commit comments