Skip to content

Commit 1896c77

Browse files
fix gradient(nan) when two inputs are equal (#32448)
1 parent 727b28d commit 1896c77

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

paddle/fluid/operators/dist_op.h

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -167,6 +167,7 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
167167
auto sign =
168168
(x_minux_y > static_cast<T>(0)).template cast<T>() * static_cast<T>(1.0) +
169169
(x_minux_y < static_cast<T>(0)).template cast<T>() * static_cast<T>(-1.0);
170+
T epsilon = static_cast<T>(1.0e-10f);
170171

171172
// 1: Lp-norm(z), z = x-y, compute dz
172173
if (p == 0) {
@@ -189,12 +190,14 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
189190
// dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout
190191
if (platform::is_cpu_place(context.GetPlace())) {
191192
grad_t.device(place) =
192-
(x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) *
193+
(x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
194+
.pow(p - 1) *
193195
sign.eval() * out_grad_t.broadcast(out_bcast_dims);
194196
} else {
195197
grad_t.device(place) =
196-
(x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) * sign *
197-
out_grad_t.broadcast(out_bcast_dims);
198+
(x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
199+
.pow(p - 1) *
200+
sign * out_grad_t.broadcast(out_bcast_dims);
198201
}
199202
}
200203

0 commit comments

Comments
 (0)