diff --git a/optimistix/_solver/fixed_point.py b/optimistix/_solver/fixed_point.py index 1b8f64b8..b2835bc7 100644 --- a/optimistix/_solver/fixed_point.py +++ b/optimistix/_solver/fixed_point.py @@ -57,8 +57,9 @@ def step( ) -> tuple[Y, _FixedPointState, Aux]: new_y, aux = fn(y, args) error = (y**ω - new_y**ω).ω - scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω - new_state = _FixedPointState(self.norm((error**ω / scale**ω).ω)) + with jax.numpy_dtype_promotion("standard"): + scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω + new_state = _FixedPointState(self.norm((error**ω / scale**ω).ω)) return new_y, new_state, aux def terminate( diff --git a/optimistix/_solver/levenberg_marquardt.py b/optimistix/_solver/levenberg_marquardt.py index 3b0db341..031a9438 100644 --- a/optimistix/_solver/levenberg_marquardt.py +++ b/optimistix/_solver/levenberg_marquardt.py @@ -2,6 +2,7 @@ from typing import cast, Generic, Union import equinox as eqx +import jax import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu @@ -25,7 +26,8 @@ class _Damped(eqx.Module, strict=True): def __call__(self, y: PyTree[Array]): residual = self.operator.mv(y) - damped = jtu.tree_map(lambda yi: jnp.sqrt(self.damping) * yi, y) + with jax.numpy_dtype_promotion("standard"): + damped = jtu.tree_map(lambda yi: jnp.sqrt(self.damping) * yi, y) return residual, damped diff --git a/optimistix/_solver/newton_chord.py b/optimistix/_solver/newton_chord.py index 528cfea3..3d05cf05 100644 --- a/optimistix/_solver/newton_chord.py +++ b/optimistix/_solver/newton_chord.py @@ -135,7 +135,8 @@ def step( if lower is not None or upper is not None: diff = (y**ω - new_y**ω).ω scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω - diffsize = self.norm((diff**ω / scale**ω).ω) + with jax.numpy_dtype_promotion("standard"): + diffsize = self.norm((diff**ω / scale**ω).ω) if self.cauchy_termination: f_val = fx else: diff --git a/tests/helpers.py b/tests/helpers.py index 40f7afba..29c57674 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -40,9 +40,10 @@ def finite_difference_jvp(fn, primals, tangents, eps=None, **kwargs): # done to a tolerance of 1e-8 or so: the primal pass is already noisy at about # the scale of ε. ε = eps - primals_ε = (ω(primals) + ε * ω(tangents)).ω - out_ε = fn(*primals_ε, **kwargs) - tangents_out = jtu.tree_map(lambda x, y: (x - y) / ε, out_ε, out) + with jax.numpy_dtype_promotion("standard"): + primals_ε = (ω(primals) + ε * ω(tangents)).ω + out_ε = fn(*primals_ε, **kwargs) + tangents_out = jtu.tree_map(lambda x, y: (x - y) / ε, out_ε, out) # We actually return the perturbed primal. # This should still be within all tolerance checks, and means that we have aceesss # to both the true primal and the perturbed primal when debugging. @@ -488,7 +489,8 @@ def _getsize(y: PyTree[Array]): def _laplacian(y: PyTree[Array], dx: Scalar): (y, unflatten) = jfu.ravel_pytree(y) laplacian = jnp.zeros_like(y) - laplacian = laplacian.at[1:-1].set((y[2:] + y[1:-1] + y[:-2]) / dx) + with jax.numpy_dtype_promotion("standard"): + laplacian = laplacian.at[1:-1].set((y[2:] + y[1:-1] + y[:-2]) / dx) return unflatten(y) @@ -508,7 +510,8 @@ def _nonlinear_heat_pde_general( const = args stepsize = t1 - t0 f_val = ((1 - y**ω) * _laplacian(y, dx) ** ω).ω - return const * (y0**ω + 0.5 * stepsize * (f_val**ω + f0**ω)).ω + with jax.numpy_dtype_promotion("standard"): + return const * (y0**ω + 0.5 * stepsize * (f_val**ω + f0**ω)).ω # Note that the midpoint methods below assume that `f` is autonomous. diff --git a/tests/test_fixed_point.py b/tests/test_fixed_point.py index 3d5c8c8d..050935a2 100644 --- a/tests/test_fixed_point.py +++ b/tests/test_fixed_point.py @@ -45,7 +45,10 @@ def test_fixed_point(solver, _fn, init, args): @pytest.mark.parametrize("solver", _fp_solvers) @pytest.mark.parametrize("_fn, init, args", fixed_point_fn_init_args) -def test_fixed_point_jvp(getkey, solver, _fn, init, args): +@pytest.mark.parametrize("dtype", [jnp.float64, jnp.complex128]) +def test_fixed_point_jvp(getkey, solver, _fn, init, dtype, args): + args = jtu.tree_map(lambda x: x.astype(dtype), args) + init = jtu.tree_map(lambda x: x.astype(dtype), init) atol = rtol = 1e-3 has_aux = random.choice([True, False]) if has_aux: @@ -54,8 +57,10 @@ def test_fixed_point_jvp(getkey, solver, _fn, init, args): fn = _fn dynamic_args, static_args = eqx.partition(args, eqx.is_array) - t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), init) - t_dynamic_args = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), dynamic_args) + t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=dtype), init) + t_dynamic_args = jtu.tree_map( + lambda x: jr.normal(getkey(), x.shape, dtype=dtype), dynamic_args + ) def fixed_point(x, dynamic_args, *, adjoint): args = eqx.combine(dynamic_args, static_args) diff --git a/tests/test_root_find.py b/tests/test_root_find.py index c9c97300..2c76c01c 100644 --- a/tests/test_root_find.py +++ b/tests/test_root_find.py @@ -52,7 +52,10 @@ def root_find_problem(y, args): @pytest.mark.parametrize("solver", _root_finders) @pytest.mark.parametrize("_fn, init, args", fixed_point_fn_init_args) -def test_root_find_jvp(getkey, solver, _fn, init, args): +@pytest.mark.parametrize("dtype", [jnp.float64, jnp.complex128]) +def test_root_find_jvp(getkey, solver, _fn, init, dtype, args): + args = jtu.tree_map(lambda x: x.astype(dtype), args) + init = jtu.tree_map(lambda x: x.astype(dtype), init) atol = rtol = 1e-3 has_aux = random.choice([True, False]) @@ -65,8 +68,10 @@ def root_find_problem(y, args): else: fn = root_find_problem dynamic_args, static_args = eqx.partition(args, eqx.is_array) - t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), init) - t_dynamic_args = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), dynamic_args) + t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=dtype), init) + t_dynamic_args = jtu.tree_map( + lambda x: jr.normal(getkey(), x.shape, dtype=dtype), dynamic_args + ) def root_find(x, dynamic_args, *, adjoint): args = eqx.combine(dynamic_args, static_args)