Skip to content

Commit c4deca4

Browse files
authored
Enable more complex tests, fix related errors (#392)
* Fix complex tests * Fix more complex tests * New sde related fixes * New sde related fixes
1 parent a998093 commit c4deca4

20 files changed

+221
-128
lines changed

diffrax/_brownian/path.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import jax.tree_util as jtu
1010
import lineax.internal as lxi
1111
from jaxtyping import Array, PRNGKeyArray, PyTree
12+
from lineax.internal import complex_to_real_dtype
1213

1314
from .._custom_types import (
1415
AbstractBrownianIncrement,
@@ -130,7 +131,7 @@ def _evaluate_leaf(
130131
):
131132
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
132133
w = jr.normal(key, shape.shape, shape.dtype) * w_std
133-
dt = jnp.asarray(t1 - t0, dtype=shape.dtype)
134+
dt = jnp.asarray(t1 - t0, dtype=complex_to_real_dtype(shape.dtype))
134135

135136
if levy_area is SpaceTimeLevyArea:
136137
key, key_hh = jr.split(key, 2)

diffrax/_brownian/tree.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import jax.random as jr
1111
import jax.tree_util as jtu
1212
import lineax.internal as lxi
13-
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
13+
from jaxtyping import Array, Inexact, PRNGKeyArray, PyTree
14+
from lineax.internal import complex_to_real_dtype
1415

1516
from .._custom_types import (
1617
AbstractBrownianIncrement,
@@ -54,9 +55,9 @@
5455
# For the midpoint rule for generating space-time Levy area see Theorem 6.1.6.
5556
# For the general interpolation rule for space-time Levy area see Theorem 6.1.4.
5657

57-
FloatDouble: TypeAlias = tuple[Float[Array, " *shape"], Float[Array, " *shape"]]
58+
FloatDouble: TypeAlias = tuple[Inexact[Array, " *shape"], Inexact[Array, " *shape"]]
5859
FloatTriple: TypeAlias = tuple[
59-
Float[Array, " *shape"], Float[Array, " *shape"], Float[Array, " *shape"]
60+
Inexact[Array, " *shape"], Inexact[Array, " *shape"], Inexact[Array, " *shape"]
6061
]
6162
_Spline: TypeAlias = Literal["sqrt", "quad", "zero"]
6263
_BrownianReturn = TypeVar("_BrownianReturn", bound=AbstractBrownianIncrement)
@@ -90,7 +91,7 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
9091
assert len(x1) == 2
9192
dt0, w0 = x0
9293
dt1, w1 = x1
93-
su = jnp.asarray(dt1 - dt0, dtype=w0.dtype)
94+
su = jnp.asarray(dt1 - dt0, dtype=complex_to_real_dtype(w0.dtype))
9495
return BrownianIncrement(dt=su, W=w1 - w0)
9596

9697
elif len(x0) == 4: # space-time levy area case
@@ -99,12 +100,13 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
99100
dt1, w1, hh1, bhh1 = x1
100101

101102
w_su = w1 - w0
102-
su = jnp.asarray(dt1 - dt0, dtype=w0.dtype)
103+
su = jnp.asarray(dt1 - dt0, dtype=complex_to_real_dtype(w0.dtype))
103104
_su = jnp.where(jnp.abs(su) < jnp.finfo(su).eps, jnp.inf, su)
104105
inverse_su = 1 / _su
105-
u_bb_s = dt1 * w0 - dt0 * w1
106-
bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
107-
hh_su = inverse_su * bhh_su
106+
with jax.numpy_dtype_promotion("standard"):
107+
u_bb_s = dt1 * w0 - dt0 * w1
108+
bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
109+
hh_su = inverse_su * bhh_su
108110
return SpaceTimeLevyArea(dt=su, W=w_su, H=hh_su)
109111
else:
110112
assert False
@@ -283,9 +285,10 @@ def _evaluate_leaf(
283285
tuple[RealScalarLike, Array], tuple[RealScalarLike, Array, Array, Array]
284286
]:
285287
shape, dtype = struct.shape, struct.dtype
288+
tdtype = complex_to_real_dtype(dtype)
286289

287-
t0 = jnp.zeros((), dtype)
288-
r = jnp.asarray(r, dtype)
290+
t0 = jnp.zeros((), tdtype)
291+
r = jnp.asarray(r, tdtype)
289292

290293
if self.levy_area is SpaceTimeLevyArea:
291294
state_key, init_key_w, init_key_la = jr.split(key, 3)
@@ -394,27 +397,31 @@ def _body_fun(_state: _State):
394397
a = d_prime * sr3 * sr_ru_half
395398
b = d_prime * ru3 * sr_ru_half
396399

397-
w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
398-
w_r = w_s + w_sr
399-
c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
400-
bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
401-
bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)
400+
with jax.numpy_dtype_promotion("standard"):
401+
w_sr = (
402+
sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
403+
)
404+
w_r = w_s + w_sr
405+
c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
406+
bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
407+
bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)
402408

403-
inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r)
404-
hh_r = inverse_r * bhh_r
409+
inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r)
410+
hh_r = inverse_r * bhh_r
405411

406412
elif self.levy_area is BrownianIncrement:
407-
w_mean = w_s + sr / su * w_su
408-
if self._spline == "sqrt":
409-
z = jr.normal(final_state.key, shape, dtype)
410-
bb = jnp.sqrt(sr * ru / su) * z
411-
elif self._spline == "quad":
412-
z = jr.normal(final_state.key, shape, dtype)
413-
bb = (sr * ru / su) * z
414-
elif self._spline == "zero":
415-
bb = jnp.zeros(shape, dtype)
416-
else:
417-
assert False
413+
with jax.numpy_dtype_promotion("standard"):
414+
w_mean = w_s + sr / su * w_su
415+
if self._spline == "sqrt":
416+
z = jr.normal(final_state.key, shape, dtype)
417+
bb = jnp.sqrt(sr * ru / su) * z
418+
elif self._spline == "quad":
419+
z = jr.normal(final_state.key, shape, dtype)
420+
bb = (sr * ru / su) * z
421+
elif self._spline == "zero":
422+
bb = jnp.zeros(shape, dtype)
423+
else:
424+
assert False
418425
w_r = w_mean + bb
419426
return r, w_r
420427

@@ -497,8 +504,8 @@ def _brownian_arch(
497504

498505
w_t = w_s + w_st
499506
w_stu = (w_s, w_t, w_u)
500-
501-
bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t)
507+
with jax.numpy_dtype_promotion("standard"):
508+
bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t)
502509
bhh_stu = (bhh_s, bhh_t, bhh_u)
503510
bkk_stu = None
504511
bkk_st_tu = None

diffrax/_global_interpolation.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,10 @@ def _index(_ys):
143143
next_t = self.ts[index + 1]
144144
diff_t = next_t - prev_t
145145

146-
return (prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)).ω
146+
with jax.numpy_dtype_promotion("standard"):
147+
return (
148+
prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)
149+
).ω
147150

148151
@eqx.filter_jit
149152
def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]:
@@ -165,10 +168,11 @@ def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]:
165168

166169
index, _ = self._interpret_t(t, left)
167170

168-
return (
169-
(ω(self.ys)[index + 1] - ω(self.ys)[index])
170-
/ (self.ts[index + 1] - self.ts[index])
171-
).ω
171+
with jax.numpy_dtype_promotion("standard"):
172+
return (
173+
(ω(self.ys)[index + 1] - ω(self.ys)[index])
174+
/ (self.ts[index + 1] - self.ts[index])
175+
).ω
172176

173177

174178
LinearInterpolation.__init__.__doc__ = """**Arguments:**
@@ -254,10 +258,11 @@ def evaluate(
254258

255259
d, c, b, a = self.coeffs
256260

257-
return (
258-
ω(a)[index]
259-
+ frac * (ω(b)[index] + frac * (ω(c)[index] + frac * ω(d)[index]))
260-
).ω
261+
with jax.numpy_dtype_promotion("standard"):
262+
return (
263+
ω(a)[index]
264+
+ frac * (ω(b)[index] + frac * (ω(c)[index] + frac * ω(d)[index]))
265+
).ω
261266

262267
@eqx.filter_jit
263268
def derivative(
@@ -283,7 +288,8 @@ def derivative(
283288

284289
d, c, b, _ = self.coeffs
285290

286-
return (ω(b)[index] + frac * (2 * ω(c)[index] + frac * 3 * ω(d)[index])).ω
291+
with jax.numpy_dtype_promotion("standard"):
292+
return (ω(b)[index] + frac * (2 * ω(c)[index] + frac * 3 * ω(d)[index])).ω
287293

288294

289295
CubicInterpolation.__init__.__doc__ = """**Arguments:**
@@ -622,8 +628,9 @@ def _hermite_forward(
622628
]:
623629
prev_ti, prev_yi, prev_deriv_i = carry
624630
ti, yi, next_ti, next_yi = value
625-
first_deriv_i = (next_yi - yi) / (next_ti - ti)
626-
later_deriv_i = (yi - prev_yi) / (ti - prev_ti)
631+
with jax.numpy_dtype_promotion("standard"):
632+
first_deriv_i = (next_yi - yi) / (next_ti - ti)
633+
later_deriv_i = (yi - prev_yi) / (ti - prev_ti)
627634
deriv_i = jnp.where(jnp.isnan(prev_yi), first_deriv_i, later_deriv_i)
628635
cond = jnp.isnan(yi)
629636
carry_ti = jnp.where(cond, prev_ti, ti)
@@ -635,13 +642,15 @@ def _hermite_forward(
635642

636643
def _hermite_coeffs(t0, y0, deriv0, t1, y1):
637644
t_diff = t1 - t0
638-
deriv1 = (y1 - y0) / t_diff
639-
d_deriv = deriv1 - deriv0
640645

641-
a = y0
642-
b = deriv0
643-
c = 2 * d_deriv / t_diff
644-
d = -d_deriv / t_diff**2
646+
with jax.numpy_dtype_promotion("standard"):
647+
deriv1 = (y1 - y0) / t_diff
648+
d_deriv = deriv1 - deriv0
649+
650+
a = y0
651+
b = deriv0
652+
c = 2 * d_deriv / t_diff
653+
d = -d_deriv / (t_diff**2)
645654

646655
return d, c, b, a
647656

@@ -684,7 +693,8 @@ def _backward_hermite_coefficients(
684693
else:
685694
y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape)
686695
if deriv0 is None:
687-
deriv0 = (next_ys[0] - y0) / (next_ts[0] - t0)
696+
with jax.numpy_dtype_promotion("standard"):
697+
deriv0 = (next_ys[0] - y0) / (next_ts[0] - t0)
688698
else:
689699
deriv0 = jnp.broadcast_to(deriv0, ys[0].shape)
690700
ts = ts[:-1]

diffrax/_integrate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ def _check(term_cls, term, term_contr_kwargs, yi):
155155
# If we've got to this point then the term is compatible
156156

157157
try:
158-
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
158+
with jax.numpy_dtype_promotion("standard"):
159+
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
159160
except ValueError:
160161
# ValueError may also arise from mismatched tree structures
161162
return False

diffrax/_local_interpolation.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Callable
22
from typing import cast, Optional, TYPE_CHECKING
33

4+
import jax
45
import jax.numpy as jnp
56
import jax.tree_util as jtu
67
import numpy as np
@@ -35,12 +36,15 @@ def evaluate(
3536
self, t0: RealScalarLike, t1: Optional[RealScalarLike] = None, left: bool = True
3637
) -> PyTree[Array]:
3738
del left
38-
if t1 is None:
39-
coeff = linear_rescale(self.t0, t0, self.t1)
40-
return (self.y0**ω + coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω
41-
else:
42-
coeff = (t1 - t0) / (self.t1 - self.t0)
43-
return (coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω
39+
with jax.numpy_dtype_promotion("standard"):
40+
if t1 is None:
41+
coeff = linear_rescale(self.t0, t0, self.t1)
42+
return (
43+
(self.y0**ω + coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω
44+
)
45+
else:
46+
coeff = (t1 - t0) / (self.t1 - self.t0)
47+
return (coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω
4448

4549

4650
class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation):
@@ -82,7 +86,8 @@ def evaluate(
8286
t = linear_rescale(self.t0, t0, self.t1)
8387

8488
def _eval(_coeffs):
85-
return jnp.polyval(_coeffs, t)
89+
with jax.numpy_dtype_promotion("standard"):
90+
return jnp.polyval(_coeffs, t)
8691

8792
return jtu.tree_map(_eval, self.coeffs)
8893

@@ -104,7 +109,8 @@ def __init__(
104109
k: PyTree[Shaped[Array, "order ?*y"], "Y"],
105110
):
106111
def _calculate(_y0, _y1, _k):
107-
_ymid = _y0 + jnp.tensordot(self.c_mid, _k, axes=1)
112+
with jax.numpy_dtype_promotion("standard"):
113+
_ymid = _y0 + jnp.tensordot(self.c_mid, _k, axes=1)
108114
_f0 = _k[0]
109115
_f1 = _k[-1]
110116
# TODO: rewrite as matrix-vector product?
@@ -127,6 +133,7 @@ def evaluate(
127133
t = linear_rescale(self.t0, t0, self.t1)
128134

129135
def _eval(_coeffs):
130-
return jnp.polyval(_coeffs, t)
136+
with jax.numpy_dtype_promotion("standard"):
137+
return jnp.polyval(_coeffs, t)
131138

132139
return jtu.tree_map(_eval, self.coeffs)

diffrax/_root_finder/_verychord.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import optimistix as optx
1212
from equinox.internal import ω
1313
from jaxtyping import Array, Bool, PyTree, Scalar
14+
from lineax.internal import complex_to_real_dtype
1415

1516
from .._custom_types import Y
1617

@@ -97,11 +98,12 @@ def init(
9798
y_dtype = lxi.default_floating_dtype()
9899
else:
99100
y_dtype = jnp.result_type(*y_leaves)
101+
diff_dtype = complex_to_real_dtype(y_dtype)
100102
init_state = _VeryChordState(
101103
linear_state=linear_state,
102104
diff=jtu.tree_map(lambda x: jnp.full(x.shape, jnp.inf, x.dtype), y),
103-
diffsize=jnp.array(jnp.inf, dtype=y_dtype),
104-
diffsize_prev=jnp.array(1.0, dtype=y_dtype),
105+
diffsize=jnp.array(jnp.inf, dtype=diff_dtype),
106+
diffsize_prev=jnp.array(1.0, dtype=diff_dtype),
105107
result=optx.RESULTS.successful,
106108
step=jnp.array(0),
107109
)
@@ -127,8 +129,10 @@ def step(
127129
)
128130
diff = sol.value
129131
new_y = (y**ω - diff**ω).ω
130-
scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
131-
diffsize = self.norm((diff**ω / scale**ω).ω)
132+
133+
with jax.numpy_dtype_promotion("standard"):
134+
scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
135+
diffsize = self.norm((diff**ω / scale**ω).ω)
132136
new_state = _VeryChordState(
133137
linear_state=state.linear_state,
134138
diff=diff,

diffrax/_solver/dopri8.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,8 @@ def evaluate(
298298
return self.evaluate(t1) - self.evaluate(t0)
299299
t = linear_rescale(self.t0, t0, self.t1)
300300
coeffs = _vmap_polyval(jnp.asarray(self.eval_coeffs, dtype=t.dtype), t) * t
301-
return (self.y0**ω + vector_tree_dot(coeffs, self.k) ** ω).ω
301+
with jax.numpy_dtype_promotion("standard"):
302+
return (self.y0**ω + vector_tree_dot(coeffs, self.k) ** ω).ω
302303

303304

304305
class Dopri8(AbstractERK):

diffrax/_solver/kencarp3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ def evaluate(
117117
explicit_k, implicit_k = self.k
118118
k = (explicit_k**ω + implicit_k**ω).ω
119119
coeffs = t * jax.vmap(lambda row: jnp.polyval(row, t))(self.coeffs)
120-
return (self.y0**ω + vector_tree_dot(coeffs, k) ** ω).ω
120+
with jax.numpy_dtype_promotion("standard"):
121+
return (self.y0**ω + vector_tree_dot(coeffs, k) ** ω).ω
121122

122123

123124
class _KenCarp3Interpolation(KenCarpInterpolation):

diffrax/_solver/milstein.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,8 @@ def step(
214214
leaf = jnp.tensordot(l1[..., None], l2[None, ...], axes=1)
215215
if i1 == i2:
216216
eye = jnp.eye(l1.size).reshape(l1.shape + l1.shape)
217-
leaf = leaf - Δt * eye
217+
with jax.numpy_dtype_promotion("standard"):
218+
leaf = leaf - Δt * eye
218219
leaves_ΔwΔw.append(leaf)
219220
tree_ΔwΔw = tree_Δw.compose(tree_Δw)
220221
ΔwΔw = jtu.tree_unflatten(tree_ΔwΔw, leaves_ΔwΔw)
@@ -236,7 +237,9 @@ def _to_vmap(_g0):
236237
# _g0 has structure (tree(y0), leaf(y0))
237238
_, _jvp = jax.jvp(_to_vjp, (y0,), (_g0,))
238239
# jvp has structure (tree(g0), leaf(g0))
239-
_jvp_matrix = jax.jacfwd(lambda _Δw: diffusion.prod(_jvp, _Δw))(Δw)
240+
_jvp_matrix = jax.jacfwd(
241+
lambda _Δw: diffusion.prod(_jvp, _Δw), holomorphic=jnp.iscomplexobj(Δw)
242+
)(Δw)
240243
# _jvp_matrix has structure (tree(y0), tree(Δw), leaf(y0), leaf(Δw))
241244
return _jvp_matrix
242245

@@ -282,7 +285,9 @@ def _to_treemap(_Δw, _g0):
282285
Δw_treedef = jtu.tree_structure(Δw)
283286
# g0 has structure (tree(g0), leaf(g0))
284287
# Which we now transform into its isomorphic matrix form, as above.
285-
g0_matrix = jax.jacfwd(lambda _Δw: diffusion.prod(g0, _Δw))(Δw)
288+
g0_matrix = jax.jacfwd(
289+
lambda _Δw: diffusion.prod(g0, _Δw), holomorphic=jnp.iscomplexobj(Δw)
290+
)(Δw)
286291
# g0_matrix has structure (tree(y0), tree(Δw), leaf(y0), leaf(Δw))
287292
g0_matrix = jtu.tree_transpose(y_treedef, Δw_treedef, g0_matrix)
288293
# g0_matrix has structure (tree(Δw), tree(y0), leaf(y0), leaf(Δw))

0 commit comments

Comments
 (0)