
Commit 304bb3d

hertschuh authored and tensorflower-gardener committed
Fix for Keras Softmax layer gradient underflow.
tensorflow/tensorflow#60314

The `tf.keras.activations.softmax` function, the `tf.keras.backend.softmax` function, and the `tf.keras.layers.Softmax` layer now behave consistently and save the logits in `_keras_logits`. Previously, only the activation function had this behavior. Caching the logits prevents the computation of the gradient of the crossentropy from underflowing.

The same fix was applied to the `tf.keras.backend.sigmoid` function and the `tf.keras.layers.Sigmoid` layer.

One behavior change is that `tf.keras.backend.softmax` and `tf.keras.layers.Softmax` no longer accept inputs of rank 1.

PiperOrigin-RevId: 536456175
1 parent ddf134e commit 304bb3d
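
For background, here is a minimal sketch (not part of this commit) of the failure mode being fixed: when the softmax saturates, the probability of the true class rounds to zero in float32, so a crossentropy computed from the probabilities back-propagates a zero gradient, while a crossentropy computed from the logits keeps the correct gradient.

```python
# Minimal illustration (assumes TF 2.x; not code from this commit).
import tensorflow as tf

logits = tf.constant([[-100.0, 0.0, 100.0]])
labels = tf.constant([[1.0, 0.0, 0.0]])  # true class probability underflows to 0

with tf.GradientTape(persistent=True) as tape:
    tape.watch(logits)
    probs = tf.nn.softmax(logits)
    # Crossentropy from probabilities: log() sees a value that already
    # underflowed to 0, so the gradient through the softmax vanishes.
    from_probs = -tf.reduce_sum(labels * tf.math.log(probs + 1e-7), axis=-1)
    # Crossentropy from logits: numerically stable log-sum-exp form.
    from_logits = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)

print(tape.gradient(from_probs, logits).numpy())   # [[0. 0. 0.]] -- gradient underflowed
print(tape.gradient(from_logits, logits).numpy())  # ~[[-1. 0. 1.]] -- correct (probs - labels)
```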

File tree

3 files changed: +26 −26 lines


keras/activations.py

Lines changed: 2 additions & 20 deletions

@@ -84,22 +84,7 @@ def softmax(x, axis=-1):
     >>> layer = tf.keras.layers.Dense(32,
     ... activation=tf.keras.activations.softmax)
     """
-    if x.shape.rank <= 1:
-        raise ValueError(
-            f"Cannot apply softmax to a tensor that is 1D. Received input: {x}"
-        )
-
-    if isinstance(axis, int):
-        output = tf.nn.softmax(x, axis=axis)
-    else:
-        # nn.softmax does not support tuple axis.
-        numerator = tf.exp(x - tf.reduce_max(x, axis=axis, keepdims=True))
-        denominator = tf.reduce_sum(numerator, axis=axis, keepdims=True)
-        output = numerator / denominator
-
-    # Cache the logits to use for crossentropy loss.
-    output._keras_logits = x
-    return output
+    return backend.softmax(x, axis)


 @keras_export("keras.activations.elu")

@@ -412,10 +397,7 @@ def sigmoid(x):
     Returns:
         Tensor with the sigmoid activation: `1 / (1 + exp(-x))`.
     """
-    output = tf.sigmoid(x)
-    # Cache the logits to use for crossentropy loss.
-    output._keras_logits = x
-    return output
+    return backend.sigmoid(x)


 @keras_export("keras.activations.exponential")
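
With the activation now delegating to the backend, both entry points take the same path and tag their output with the logits. A quick illustrative check, assuming a build with this change applied:

```python
# Illustrative only; assumes this commit is applied.
import tensorflow as tf

x = tf.constant([[1.0, 2.0, 3.0]])

out_act = tf.keras.activations.softmax(x)
out_backend = tf.keras.backend.softmax(x)

print(hasattr(out_act, "_keras_logits"))      # True (unchanged behavior)
print(hasattr(out_backend, "_keras_logits"))  # True (new with this commit)
```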

keras/backend.py

Lines changed: 20 additions & 2 deletions

@@ -5441,7 +5441,22 @@ def softmax(x, axis=-1):
     Returns:
         A tensor.
     """
-    return tf.nn.softmax(x, axis=axis)
+    if x.shape.rank <= 1:
+        raise ValueError(
+            f"Cannot apply softmax to a tensor that is 1D. Received input: {x}"
+        )
+
+    if isinstance(axis, int):
+        output = tf.nn.softmax(x, axis=axis)
+    else:
+        # nn.softmax does not support tuple axis.
+        numerator = tf.exp(x - tf.reduce_max(x, axis=axis, keepdims=True))
+        denominator = tf.reduce_sum(numerator, axis=axis, keepdims=True)
+        output = numerator / denominator
+
+    # Cache the logits to use for crossentropy loss.
+    output._keras_logits = x
+    return output


 @keras_export("keras.backend.softplus")

@@ -5899,7 +5914,10 @@ def sigmoid(x):
     Returns:
         A tensor.
     """
-    return tf.math.sigmoid(x)
+    output = tf.sigmoid(x)
+    # Cache the logits to use for crossentropy loss.
+    output._keras_logits = x
+    return output


 @keras_export("keras.backend.hard_sigmoid")
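
For context, the cached attribute is consumed on the loss side. The sketch below is a simplified, assumed rendering of how a crossentropy helper can pick up `_keras_logits` and switch to the stable logits path; it is not code from this diff.

```python
import tensorflow as tf


def categorical_crossentropy_sketch(target, output, from_logits=False, axis=-1):
    """Simplified sketch of a _keras_logits-aware crossentropy (assumed, not from this diff)."""
    if hasattr(output, "_keras_logits"):
        # Recover the pre-softmax logits cached by backend.softmax / activations.softmax.
        output = output._keras_logits
        from_logits = True

    if from_logits:
        return tf.nn.softmax_cross_entropy_with_logits(
            labels=target, logits=output, axis=axis
        )

    # Probability path: clip to avoid log(0), then compute -sum(target * log(output)).
    output = tf.clip_by_value(output, 1e-7, 1.0 - 1e-7)
    return -tf.reduce_sum(target * tf.math.log(output), axis=axis)
```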

keras/layers/activation/softmax.py

Lines changed: 4 additions & 4 deletions

@@ -51,13 +51,13 @@ class Softmax(Layer):

     Example without mask:

-    >>> inp = np.asarray([1., 2., 1.])
+    >>> inp = np.asarray([[1., 2., 1.]])
     >>> layer = tf.keras.layers.Softmax()
     >>> layer(inp).numpy()
-    array([0.21194157, 0.5761169 , 0.21194157], dtype=float32)
-    >>> mask = np.asarray([True, False, True], dtype=bool)
+    array([[0.21194157, 0.5761169 , 0.21194157]], dtype=float32)
+    >>> mask = np.asarray([[True, False, True]], dtype=bool)
     >>> layer(inp, mask).numpy()
-    array([0.5, 0. , 0.5], dtype=float32)
+    array([[0.5, 0. , 0.5]], dtype=float32)

     Input shape:
         Arbitrary. Use the keyword argument `input_shape`
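
The updated docstring example reflects the documented behavior change: the layer now expects inputs of rank 2 or higher. An illustrative check, assuming a build with this change:

```python
# Illustrative only; assumes this commit is applied.
import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Softmax()
print(layer(np.asarray([[1.0, 2.0, 1.0]])).numpy())  # rank-2 input: works as documented

try:
    layer(np.asarray([1.0, 2.0, 1.0]))  # rank-1 input is no longer accepted
except ValueError as err:
    print(err)  # e.g. "Cannot apply softmax to a tensor that is 1D. ..."
```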

0 commit comments
