Skip to content

Commit d73b01e

Browse files
committed
complex fixes
1 parent 890539f commit d73b01e

File tree

6 files changed

+32
-15
lines changed

6 files changed

+32
-15
lines changed

optimistix/_solver/fixed_point.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ def step(
5757
) -> tuple[Y, _FixedPointState, Aux]:
5858
new_y, aux = fn(y, args)
5959
error = (y**ω - new_y**ω).ω
60-
scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
61-
new_state = _FixedPointState(self.norm((error**ω / scale**ω).ω))
60+
with jax.numpy_dtype_promotion("standard"):
61+
scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
62+
new_state = _FixedPointState(self.norm((error**ω / scale**ω).ω))
6263
return new_y, new_state, aux
6364

6465
def terminate(

optimistix/_solver/levenberg_marquardt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import cast, Generic, Union
33

44
import equinox as eqx
5+
import jax
56
import jax.lax as lax
67
import jax.numpy as jnp
78
import jax.tree_util as jtu
@@ -25,7 +26,8 @@ class _Damped(eqx.Module, strict=True):
2526

2627
def __call__(self, y: PyTree[Array]):
2728
residual = self.operator.mv(y)
28-
damped = jtu.tree_map(lambda yi: jnp.sqrt(self.damping) * yi, y)
29+
with jax.numpy_dtype_promotion("standard"):
30+
damped = jtu.tree_map(lambda yi: jnp.sqrt(self.damping) * yi, y)
2931
return residual, damped
3032

3133

optimistix/_solver/newton_chord.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def step(
135135
if lower is not None or upper is not None:
136136
diff = (y**ω - new_y**ω).ω
137137
scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
138-
diffsize = self.norm((diff**ω / scale**ω).ω)
138+
with jax.numpy_dtype_promotion("standard"):
139+
diffsize = self.norm((diff**ω / scale**ω).ω)
139140
if self.cauchy_termination:
140141
f_val = fx
141142
else:

tests/helpers.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@ def finite_difference_jvp(fn, primals, tangents, eps=None, **kwargs):
4040
# done to a tolerance of 1e-8 or so: the primal pass is already noisy at about
4141
# the scale of ε.
4242
ε = eps
43-
primals_ε = (ω(primals) + ε * ω(tangents)).ω
44-
out_ε = fn(*primals_ε, **kwargs)
45-
tangents_out = jtu.tree_map(lambda x, y: (x - y) / ε, out_ε, out)
43+
with jax.numpy_dtype_promotion("standard"):
44+
primals_ε = (ω(primals) + ε * ω(tangents)).ω
45+
out_ε = fn(*primals_ε, **kwargs)
46+
tangents_out = jtu.tree_map(lambda x, y: (x - y) / ε, out_ε, out)
4647
# We actually return the perturbed primal.
4748
# This should still be within all tolerance checks, and means that we have aceesss
4849
# to both the true primal and the perturbed primal when debugging.
@@ -488,7 +489,8 @@ def _getsize(y: PyTree[Array]):
488489
def _laplacian(y: PyTree[Array], dx: Scalar):
489490
(y, unflatten) = jfu.ravel_pytree(y)
490491
laplacian = jnp.zeros_like(y)
491-
laplacian = laplacian.at[1:-1].set((y[2:] + y[1:-1] + y[:-2]) / dx)
492+
with jax.numpy_dtype_promotion("standard"):
493+
laplacian = laplacian.at[1:-1].set((y[2:] + y[1:-1] + y[:-2]) / dx)
492494
return unflatten(y)
493495

494496

@@ -508,7 +510,8 @@ def _nonlinear_heat_pde_general(
508510
const = args
509511
stepsize = t1 - t0
510512
f_val = ((1 - y**ω) * _laplacian(y, dx) ** ω).ω
511-
return const * (y0**ω + 0.5 * stepsize * (f_val**ω + f0**ω)).ω
513+
with jax.numpy_dtype_promotion("standard"):
514+
return const * (y0**ω + 0.5 * stepsize * (f_val**ω + f0**ω)).ω
512515

513516

514517
# Note that the midpoint methods below assume that `f` is autonomous.

tests/test_fixed_point.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ def test_fixed_point(solver, _fn, init, args):
4545

4646
@pytest.mark.parametrize("solver", _fp_solvers)
4747
@pytest.mark.parametrize("_fn, init, args", fixed_point_fn_init_args)
48-
def test_fixed_point_jvp(getkey, solver, _fn, init, args):
48+
@pytest.mark.parametrize("dtype", [jnp.float64, jnp.complex128])
49+
def test_fixed_point_jvp(getkey, solver, _fn, init, dtype, args):
50+
args = jtu.tree_map(lambda x: x.astype(dtype), args)
51+
init = jtu.tree_map(lambda x: x.astype(dtype), init)
4952
atol = rtol = 1e-3
5053
has_aux = random.choice([True, False])
5154
if has_aux:
@@ -54,8 +57,10 @@ def test_fixed_point_jvp(getkey, solver, _fn, init, args):
5457
fn = _fn
5558

5659
dynamic_args, static_args = eqx.partition(args, eqx.is_array)
57-
t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), init)
58-
t_dynamic_args = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), dynamic_args)
60+
t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=dtype), init)
61+
t_dynamic_args = jtu.tree_map(
62+
lambda x: jr.normal(getkey(), x.shape, dtype=dtype), dynamic_args
63+
)
5964

6065
def fixed_point(x, dynamic_args, *, adjoint):
6166
args = eqx.combine(dynamic_args, static_args)

tests/test_root_find.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ def root_find_problem(y, args):
5252

5353
@pytest.mark.parametrize("solver", _root_finders)
5454
@pytest.mark.parametrize("_fn, init, args", fixed_point_fn_init_args)
55-
def test_root_find_jvp(getkey, solver, _fn, init, args):
55+
@pytest.mark.parametrize("dtype", [jnp.float64, jnp.complex128])
56+
def test_root_find_jvp(getkey, solver, _fn, init, dtype, args):
57+
args = jtu.tree_map(lambda x: x.astype(dtype), args)
58+
init = jtu.tree_map(lambda x: x.astype(dtype), init)
5659
atol = rtol = 1e-3
5760
has_aux = random.choice([True, False])
5861

@@ -65,8 +68,10 @@ def root_find_problem(y, args):
6568
else:
6669
fn = root_find_problem
6770
dynamic_args, static_args = eqx.partition(args, eqx.is_array)
68-
t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), init)
69-
t_dynamic_args = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), dynamic_args)
71+
t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=dtype), init)
72+
t_dynamic_args = jtu.tree_map(
73+
lambda x: jr.normal(getkey(), x.shape, dtype=dtype), dynamic_args
74+
)
7075

7176
def root_find(x, dynamic_args, *, adjoint):
7277
args = eqx.combine(dynamic_args, static_args)

0 commit comments

Comments
 (0)