File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed
Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -167,6 +167,7 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
167167 auto sign =
168168 (x_minux_y > static_cast <T>(0 )).template cast <T>() * static_cast <T>(1.0 ) +
169169 (x_minux_y < static_cast <T>(0 )).template cast <T>() * static_cast <T>(-1.0 );
170+ T epsilon = static_cast <T>(1 .0e-10f );
170171
171172 // 1: Lp-norm(z), z = x-y, compute dz
172173 if (p == 0 ) {
@@ -189,12 +190,14 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
189190 // dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout
190191 if (platform::is_cpu_place (context.GetPlace ())) {
191192 grad_t .device (place) =
192- (x_minux_y_abs / out_t .broadcast (out_bcast_dims)).pow (p - 1 ) *
193+ (x_minux_y_abs / (out_t + epsilon).broadcast (out_bcast_dims))
194+ .pow (p - 1 ) *
193195 sign.eval () * out_grad_t .broadcast (out_bcast_dims);
194196 } else {
195197 grad_t .device (place) =
196- (x_minux_y_abs / out_t .broadcast (out_bcast_dims)).pow (p - 1 ) * sign *
197- out_grad_t .broadcast (out_bcast_dims);
198+ (x_minux_y_abs / (out_t + epsilon).broadcast (out_bcast_dims))
199+ .pow (p - 1 ) *
200+ sign * out_grad_t .broadcast (out_bcast_dims);
198201 }
199202 }
200203
You can’t perform that action at this time.
0 commit comments