5 changes: 3 additions & 2 deletions optimistix/_solver/fixed_point.py
@@ -57,8 +57,9 @@ def step(
     ) -> tuple[Y, _FixedPointState, Aux]:
         new_y, aux = fn(y, args)
         error = (y**ω - new_y**ω).ω
-        scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
-        new_state = _FixedPointState(self.norm((error**ω / scale**ω).ω))
+        with jax.numpy_dtype_promotion("standard"):
+            scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
+            new_state = _FixedPointState(self.norm((error**ω / scale**ω).ω))
         return new_y, new_state, aux

     def terminate(
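A note on the pattern used throughout this PR: jax.numpy_dtype_promotion is JAX's context manager for selecting a dtype-promotion mode. Under "strict" promotion, combining two strongly-typed arrays of different dtypes (here, the real-valued scale against a possibly complex new_y) raises an error rather than promoting; wrapping the arithmetic in "standard" restores NumPy-style promotion for just that block. A minimal standalone sketch, not optimistix code:

    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", True)

    y = jnp.array([1.0 + 1.0j, 2.0 - 0.5j])  # strongly-typed complex128
    scale = 1e-6 + jnp.abs(y)                # strongly-typed float64 (abs of complex is real)

    with jax.numpy_dtype_promotion("strict"):
        try:
            _ = y / scale  # complex128 / float64: rejected under strict promotion
        except Exception as e:
            print("strict mode:", type(e).__name__)

    with jax.numpy_dtype_promotion("standard"):
        err = y / scale    # promotes to complex128, as in NumPy
        print(err.dtype)   # complex128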
4 changes: 3 additions & 1 deletion optimistix/_solver/levenberg_marquardt.py
@@ -2,6 +2,7 @@
 from typing import cast, Generic, Union
 
 import equinox as eqx
+import jax
 import jax.lax as lax
 import jax.numpy as jnp
 import jax.tree_util as jtu
@@ -25,7 +26,8 @@ class _Damped(eqx.Module, strict=True):

     def __call__(self, y: PyTree[Array]):
         residual = self.operator.mv(y)
-        damped = jtu.tree_map(lambda yi: jnp.sqrt(self.damping) * yi, y)
+        with jax.numpy_dtype_promotion("standard"):
+            damped = jtu.tree_map(lambda yi: jnp.sqrt(self.damping) * yi, y)
         return residual, damped


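Same fix, different spot: the damping parameter is real, so jnp.sqrt(self.damping) is a strongly-typed real array, and multiplying it against complex leaves of y is exactly the mixed real-times-complex product that strict promotion rejects. A toy illustration with made-up values (illustrative names, not the optimistix API):

    import jax
    import jax.numpy as jnp
    import jax.tree_util as jtu

    jax.config.update("jax_enable_x64", True)

    damping = jnp.array(1e-2, dtype=jnp.float64)  # real damping parameter
    y = {"w": jnp.array([1.0 + 2.0j]), "b": jnp.array([0.5 - 1.0j])}

    with jax.numpy_dtype_promotion("standard"):
        # float64 * complex128 promotes to complex128 here; "strict" would error.
        damped = jtu.tree_map(lambda yi: jnp.sqrt(damping) * yi, y)

    print({k: v.dtype for k, v in damped.items()})  # all complex128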
3 changes: 2 additions & 1 deletion optimistix/_solver/newton_chord.py
@@ -135,7 +135,8 @@ def step(
         if lower is not None or upper is not None:
             diff = (y**ω - new_y**ω).ω
             scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
-            diffsize = self.norm((diff**ω / scale**ω).ω)
+            with jax.numpy_dtype_promotion("standard"):
+                diffsize = self.norm((diff**ω / scale**ω).ω)
             if self.cauchy_termination:
                 f_val = fx
             else:
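For reference, the quantity being guarded here is the usual scaled-update convergence test, diffsize = norm((y_new - y) / (atol + rtol * |y_new|)), with the step treated as converged once it drops below 1. A self-contained sketch using an RMS norm (the solver's actual norm is configurable, so this is just one plausible choice):

    import jax.numpy as jnp

    def scaled_diffsize(y, new_y, atol, rtol):
        # The update is "small" when this is < 1, i.e. each component moved
        # by less than roughly atol + rtol * |new_y|.
        diff = new_y - y
        scale = atol + rtol * jnp.abs(new_y)  # real-valued even for complex y
        return jnp.sqrt(jnp.mean(jnp.abs(diff / scale) ** 2))

    y = jnp.array([1.0 + 1.0j, 2.0])
    new_y = y + 1e-10
    print(scaled_diffsize(y, new_y, atol=1e-8, rtol=1e-8) < 1)  # True: converged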
13 changes: 8 additions & 5 deletions tests/helpers.py
@@ -40,9 +40,10 @@ def finite_difference_jvp(fn, primals, tangents, eps=None, **kwargs):
     # done to a tolerance of 1e-8 or so: the primal pass is already noisy at about
     # the scale of ε.
     ε = eps
-    primals_ε = (ω(primals) + ε * ω(tangents)).ω
-    out_ε = fn(*primals_ε, **kwargs)
-    tangents_out = jtu.tree_map(lambda x, y: (x - y) / ε, out_ε, out)
+    with jax.numpy_dtype_promotion("standard"):
+        primals_ε = (ω(primals) + ε * ω(tangents)).ω
+        out_ε = fn(*primals_ε, **kwargs)
+        tangents_out = jtu.tree_map(lambda x, y: (x - y) / ε, out_ε, out)
     # We actually return the perturbed primal.
     # This should still be within all tolerance checks, and means that we have access
     # to both the true primal and the perturbed primal when debugging.
@@ -488,7 +489,8 @@ def _getsize(y: PyTree[Array]):
 def _laplacian(y: PyTree[Array], dx: Scalar):
     (y, unflatten) = jfu.ravel_pytree(y)
     laplacian = jnp.zeros_like(y)
-    laplacian = laplacian.at[1:-1].set((y[2:] + y[1:-1] + y[:-2]) / dx)
+    with jax.numpy_dtype_promotion("standard"):
+        laplacian = laplacian.at[1:-1].set((y[2:] + y[1:-1] + y[:-2]) / dx)
     return unflatten(y)


@@ -508,7 +510,8 @@ def _nonlinear_heat_pde_general(
     const = args
     stepsize = t1 - t0
     f_val = ((1 - y**ω) * _laplacian(y, dx) ** ω).ω
-    return const * (y0**ω + 0.5 * stepsize * (f_val**ω + f0**ω)).ω
+    with jax.numpy_dtype_promotion("standard"):
+        return const * (y0**ω + 0.5 * stepsize * (f_val**ω + f0**ω)).ω
 
 
 # Note that the midpoint methods below assume that `f` is autonomous.
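The helper above implements a one-sided finite-difference JVP, (f(x + εv) - f(x)) / ε, and deliberately returns the perturbed primal f(x + εv) rather than f(x). A stripped-down sketch of that idea (without the ε-selection and dtype-promotion details of the real helper), checked against jax.jvp:

    import jax
    import jax.numpy as jnp
    import jax.tree_util as jtu

    def fd_jvp(fn, primals, tangents, eps=1e-6):
        # Forward difference: (f(x + eps*v) - f(x)) / eps ~= df(x)[v].
        out = fn(*primals)
        primals_eps = jtu.tree_map(lambda x, v: x + eps * v, primals, tangents)
        out_eps = fn(*primals_eps)
        tangents_out = jtu.tree_map(lambda a, b: (a - b) / eps, out_eps, out)
        return out_eps, tangents_out  # perturbed primal, matching the helper

    f = lambda x: jnp.sin(x)
    primals, tangents = (jnp.array(1.0),), (jnp.array(1.0),)
    _, fd_t = fd_jvp(f, primals, tangents)
    _, ad_t = jax.jvp(f, primals, tangents)
    print(fd_t, ad_t)  # agree to ~1e-6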
11 changes: 8 additions & 3 deletions tests/test_fixed_point.py
@@ -45,7 +45,10 @@ def test_fixed_point(solver, _fn, init, args):

 @pytest.mark.parametrize("solver", _fp_solvers)
 @pytest.mark.parametrize("_fn, init, args", fixed_point_fn_init_args)
-def test_fixed_point_jvp(getkey, solver, _fn, init, args):
+@pytest.mark.parametrize("dtype", [jnp.float64, jnp.complex128])
+def test_fixed_point_jvp(getkey, solver, _fn, init, dtype, args):
+    args = jtu.tree_map(lambda x: x.astype(dtype), args)
+    init = jtu.tree_map(lambda x: x.astype(dtype), init)
     atol = rtol = 1e-3
     has_aux = random.choice([True, False])
     if has_aux:
@@ -54,8 +57,10 @@ def test_fixed_point_jvp(getkey, solver, _fn, init, args):
         fn = _fn
 
     dynamic_args, static_args = eqx.partition(args, eqx.is_array)
-    t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), init)
-    t_dynamic_args = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), dynamic_args)
+    t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=dtype), init)
+    t_dynamic_args = jtu.tree_map(
+        lambda x: jr.normal(getkey(), x.shape, dtype=dtype), dynamic_args
+    )
 
     def fixed_point(x, dynamic_args, *, adjoint):
         args = eqx.combine(dynamic_args, static_args)
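One detail that makes the dtype parametrization above cheap: jax.random.normal accepts complex dtypes directly, in which case it samples from the standard complex normal distribution, so the same jr.normal call covers both the float64 and complex128 cases. A quick standalone check:

    import jax
    import jax.numpy as jnp
    import jax.random as jr

    jax.config.update("jax_enable_x64", True)

    key = jr.PRNGKey(0)
    t_real = jr.normal(key, (3,), dtype=jnp.float64)
    t_cplx = jr.normal(key, (3,), dtype=jnp.complex128)  # standard complex normal
    print(t_real.dtype, t_cplx.dtype)  # float64 complex128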
11 changes: 8 additions & 3 deletions tests/test_root_find.py
@@ -52,7 +52,10 @@ def root_find_problem(y, args):

 @pytest.mark.parametrize("solver", _root_finders)
 @pytest.mark.parametrize("_fn, init, args", fixed_point_fn_init_args)
-def test_root_find_jvp(getkey, solver, _fn, init, args):
+@pytest.mark.parametrize("dtype", [jnp.float64, jnp.complex128])
+def test_root_find_jvp(getkey, solver, _fn, init, dtype, args):
+    args = jtu.tree_map(lambda x: x.astype(dtype), args)
+    init = jtu.tree_map(lambda x: x.astype(dtype), init)
     atol = rtol = 1e-3
     has_aux = random.choice([True, False])

@@ -65,8 +68,10 @@ def root_find_problem(y, args):
     else:
         fn = root_find_problem
     dynamic_args, static_args = eqx.partition(args, eqx.is_array)
-    t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), init)
-    t_dynamic_args = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), dynamic_args)
+    t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=dtype), init)
+    t_dynamic_args = jtu.tree_map(
+        lambda x: jr.normal(getkey(), x.shape, dtype=dtype), dynamic_args
+    )
 
     def root_find(x, dynamic_args, *, adjoint):
         args = eqx.combine(dynamic_args, static_args)