Skip to content

Commit b79b849

Browse files
marcocuturiclaude
andauthored
Remove activation on final layer of KeyNet and ICNN (#702)
* Remove activation on final layer of KeyNet and ICNN The output layer of both KeyNet.gradient and ICNN.__call__ previously applied the activation function (default ReLU) after the final layer. This forced the outputs to be non-negative: KeyNet's predicted vectors could not take signed values, and ICNN's scalar potential was clamped to be non-negative. Make the final layer linear in both networks. Convexity of the ICNN output is preserved (a non-negatively weighted combination of convex features remains convex). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * fix: yapf formatting + linen potential in cmonge_gap test Reformat the final-layer loop in ICNN/KeyNet to satisfy yapf (the CI "code" Lint check). Switch conditional_monge_gap_test to LinenPotentialMLP so it uses the linen init/apply API it was written for, matching monge_gap_test (the nnx PotentialMLP now requires input_dim/rngs). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent f959b2a commit b79b849

2 files changed

Lines changed: 17 additions & 11 deletions

File tree

src/ott/neural/networks/icnn.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,14 @@ def __call__(self, x: jax.Array) -> jax.Array:
225225

226226
z = self._act_fn_call(self.wx0(x))
227227

228-
for wx, wz in zip(self.wx_layers, self.wz_layers, strict=True):
229-
if wx is not None:
230-
z = self._act_fn_call(wz(z) + wx(x))
231-
else:
232-
z = self._act_fn_call(wz(z))
228+
num_layers = len(self.wz_layers)
229+
for i, (wx,
230+
wz) in enumerate(zip(self.wx_layers, self.wz_layers, strict=True)):
231+
z = wz(z) + wx(x) if wx is not None else wz(z)
232+
# The final layer is linear: no activation, so the (convex) potential
233+
# is an unconstrained combination of the last hidden features.
234+
if i != num_layers - 1:
235+
z = self._act_fn_call(z)
233236

234237
if self.pos_def_potentials is not None:
235238
z = z + self.pos_def_potentials(x)
@@ -399,11 +402,14 @@ def gradient(self, x: jax.Array) -> jax.Array:
399402
batch_size, _ = x.shape
400403
z = self._act_fn_call(self.wx0(x))
401404

402-
for wx, wz in zip(self.wx_layers, self.wz_layers, strict=True):
403-
if wx is not None:
404-
z = self._act_fn_call(wz(z) + wx(x))
405-
else:
406-
z = self._act_fn_call(wz(z))
405+
num_layers = len(self.wz_layers)
406+
for i, (wx,
407+
wz) in enumerate(zip(self.wx_layers, self.wz_layers, strict=True)):
408+
z = wz(z) + wx(x) if wx is not None else wz(z)
409+
# The final layer is linear: no activation, so the vector output can
410+
# take arbitrary values (e.g. signed gradients).
411+
if i != num_layers - 1:
412+
z = self._act_fn_call(z)
407413

408414
if self._resnet:
409415
z = x + z

tests/neural/methods/conditional_monge_gap_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def test_non_negativity_neural_map(
207207
rng1, rng2 = jax.random.split(rng)
208208

209209
source = jax.random.normal(rng1, (n, n_features))
210-
model = potentials.PotentialMLP(dim_hidden=[8, 8], is_potential=False)
210+
model = potentials.LinenPotentialMLP(dim_hidden=[8, 8], is_potential=False)
211211
params = model.init(rng2, x=source[0])
212212
target = model.apply(params, source)
213213
condition = jnp.repeat(jnp.arange(k), per_cond)

0 commit comments

Comments
 (0)