1717from typing import Callable
1818
1919import jax
20- from jax .nn import softplus
2120import jax .numpy as jnp
22- from jax .scipy .special import logsumexp
23- from jaxopt ._src .projection import projection_simplex , projection_hypercube
21+ from jaxopt ._src .projection import projection_simplex
22+
23+ from optax import losses as optax_losses
2424
2525
2626# Regression
@@ -39,10 +39,7 @@ def huber_loss(target: float, pred: float, delta: float = 1.0) -> float:
3939 References:
4040 https://en.wikipedia.org/wiki/Huber_loss
4141 """
42- abs_diff = jnp .abs (target - pred )
43- return jnp .where (abs_diff > delta ,
44- delta * (abs_diff - .5 * delta ),
45- 0.5 * abs_diff ** 2 )
42+ return optax_losses .huber_loss (pred , target , delta )
4643
4744# Binary classification.
4845
@@ -56,12 +53,8 @@ def binary_logistic_loss(label: int, logit: float) -> float:
5653 Returns:
5754 loss value
5855 """
59- # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1].
60- # softplus = proba * logit - xlogx(proba) - xlogx(1 - proba),
61- # where xlogx(proba) = proba * log(proba).
62- # Use -log sigmoid(logit) = softplus(-logit)
63- # and 1 - sigmoid(logit) = sigmoid(-logit).
64- return softplus (jnp .where (label , - logit , logit ))
56+ return optax_losses .sigmoid_binary_cross_entropy (
57+ jnp .asarray (logit ), jnp .asarray (label ))
6558
6659
6760def binary_sparsemax_loss (label : int , logit : float ) -> float :
@@ -77,59 +70,8 @@ def binary_sparsemax_loss(label: int, logit: float) -> float:
7770 Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
7871 Vlad Niculae. JMLR 2020. (Sec. 4.4)
7972 """
80- return sparse_plus (jnp .where (label , - logit , logit ))
81-
82-
83- def sparse_plus (x : float ) -> float :
84- r"""Sparse plus function.
85-
86- Computes the function:
87-
88- .. math::
89-
90- \mathrm{sparse\_plus}(x) = \begin{cases}
91- 0, & x \leq -1\\
92- \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
93- x, & 1 \leq x
94- \end{cases}
95-
96- This is the twin function of the softplus activation ensuring a zero output
97- for inputs less than -1 and a linear output for inputs greater than 1,
98- while remaining smooth, convex, monotonic by an adequate definition between
99- -1 and 1.
100-
101- Args:
102- x: input (float)
103- Returns:
104- sparse_plus(x) as defined above
105- """
106- return jnp .where (x <= - 1.0 , 0.0 , jnp .where (x >= 1.0 , x , (x + 1.0 )** 2 / 4 ))
107-
108-
109- def sparse_sigmoid (x : float ) -> float :
110- r"""Sparse sigmoid function.
111-
112- Computes the function:
113-
114- .. math::
115-
116- \mathrm{sparse\_sigmoid}(x) = \begin{cases}
117- 0, & x \leq -1\\
118- \frac{1}{2}(x+1), & -1 < x < 1 \\
119- 1, & 1 \leq x
120- \end{cases}
121-
122- This is the twin function of the sigmoid activation ensuring a zero output
123- for inputs less than -1, a 1 output for inputs greater than 1, and a linear
124- output for inputs between -1 and 1. This is the derivative of the sparse
125- plus function.
126-
127- Args:
128- x: input (float)
129- Returns:
130- sparse_sigmoid(x) as defined above
131- """
132- return 0.5 * projection_hypercube (x + 1.0 , 2.0 )
73+ return optax_losses .sparsemax_loss (
74+ jnp .asarray (logit ), jnp .asarray (label ))
13375
13476
13577def binary_hinge_loss (label : int , score : float ) -> float :
@@ -144,8 +86,7 @@ def binary_hinge_loss(label: int, score: float) -> float:
14486 References:
14587 https://en.wikipedia.org/wiki/Hinge_loss
14688 """
147- signed_label = 2.0 * label - 1.0
148- return jnp .maximum (0 , 1 - score * signed_label )
89+ return optax_losses .hinge_loss (score , 2.0 * label - 1.0 )
14990
15091
15192def binary_perceptron_loss (label : int , score : float ) -> float :
@@ -160,8 +101,7 @@ def binary_perceptron_loss(label: int, score: float) -> float:
160101 References:
161102 https://en.wikipedia.org/wiki/Perceptron
162103 """
163- signed_label = 2.0 * label - 1.0
164- return jnp .maximum (0 , - score * signed_label )
104+ return optax_losses .perceptron_loss (score , 2.0 * label - 1.0 )
165105
166106# Multiclass classification.
167107
@@ -175,13 +115,8 @@ def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float:
175115 Returns:
176116 loss value
177117 """
178- logits = jnp .asarray (logits )
179- # Logsumexp is the Fenchel conjugate of the Shannon negentropy on the simplex.
180- # logsumexp = jnp.dot(proba, logits) - jnp.dot(proba, jnp.log(proba))
181- # To avoid roundoff error, subtract target inside logsumexp.
182- # logsumexp(logits) - logits[y] = logsumexp(logits - logits[y])
183- logits = (logits - logits [label ]).at [label ].set (0.0 )
184- return logsumexp (logits )
118+ return optax_losses .softmax_cross_entropy_with_integer_labels (
119+ jnp .asarray (logits ), jnp .asarray (label ))
185120
186121
187122def multiclass_sparsemax_loss (label : int , scores : jnp .ndarray ) -> float :
@@ -272,5 +207,6 @@ def make_fenchel_young_loss(max_fun: Callable[[jnp.ndarray], float]):
272207 """
273208
274209 def fy_loss (y_true , scores , * args , ** kwargs ):
275- return max_fun (scores , * args , ** kwargs ) - jnp .vdot (y_true , scores )
210+ return optax_losses .make_fenchel_young_loss (max_fun )(
211+ scores .ravel (), y_true .ravel (), * args , ** kwargs )
276212 return fy_loss
0 commit comments