diffrax/_brownian/path.py (2 additions, 1 deletion)

@@ -9,6 +9,7 @@
 import jax.tree_util as jtu
 import lineax.internal as lxi
 from jaxtyping import Array, PRNGKeyArray, PyTree
+from lineax.internal import complex_to_real_dtype

 from .._custom_types import (
     AbstractBrownianIncrement,
@@ -130,7 +131,7 @@ def _evaluate_leaf(
 ):
     w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
     w = jr.normal(key, shape.shape, shape.dtype) * w_std
-    dt = jnp.asarray(t1 - t0, dtype=shape.dtype)
+    dt = jnp.asarray(t1 - t0, dtype=complex_to_real_dtype(shape.dtype))

     if levy_area is SpaceTimeLevyArea:
         key, key_hh = jr.split(key, 2)
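Note (illustrative, not part of the diff): the change above keeps the time increment `dt` real-valued even when the Brownian samples are complex. A minimal sketch of the intent, assuming `lineax.internal.complex_to_real_dtype` maps a complex dtype to its real counterpart and passes real dtypes through unchanged:

import jax.numpy as jnp
import jax.random as jr
from lineax.internal import complex_to_real_dtype

key = jr.PRNGKey(0)
t0, t1 = 0.0, 0.5
dtype = jnp.complex64  # dtype of the Brownian motion's state

# Sample a complex Brownian increment, but keep the elapsed time real.
w_std = jnp.sqrt(t1 - t0).astype(dtype)
w = jr.normal(key, (3,), dtype) * w_std
dt = jnp.asarray(t1 - t0, dtype=complex_to_real_dtype(dtype))

assert w.dtype == jnp.complex64
assert dt.dtype == jnp.float32  # complex64 maps to float32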
diffrax/_brownian/tree.py (37 additions, 30 deletions)

@@ -10,7 +10,8 @@
 import jax.random as jr
 import jax.tree_util as jtu
 import lineax.internal as lxi
-from jaxtyping import Array, Float, PRNGKeyArray, PyTree
+from jaxtyping import Array, Inexact, PRNGKeyArray, PyTree
+from lineax.internal import complex_to_real_dtype

 from .._custom_types import (
     AbstractBrownianIncrement,
@@ -54,9 +55,9 @@
 # For the midpoint rule for generating space-time Levy area see Theorem 6.1.6.
 # For the general interpolation rule for space-time Levy area see Theorem 6.1.4.

-FloatDouble: TypeAlias = tuple[Float[Array, " *shape"], Float[Array, " *shape"]]
+FloatDouble: TypeAlias = tuple[Inexact[Array, " *shape"], Inexact[Array, " *shape"]]
 FloatTriple: TypeAlias = tuple[
-    Float[Array, " *shape"], Float[Array, " *shape"], Float[Array, " *shape"]
+    Inexact[Array, " *shape"], Inexact[Array, " *shape"], Inexact[Array, " *shape"]
 ]
 _Spline: TypeAlias = Literal["sqrt", "quad", "zero"]
 _BrownianReturn = TypeVar("_BrownianReturn", bound=AbstractBrownianIncrement)
@@ -90,7 +91,7 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
         assert len(x1) == 2
         dt0, w0 = x0
         dt1, w1 = x1
-        su = jnp.asarray(dt1 - dt0, dtype=w0.dtype)
+        su = jnp.asarray(dt1 - dt0, dtype=complex_to_real_dtype(w0.dtype))
         return BrownianIncrement(dt=su, W=w1 - w0)

     elif len(x0) == 4:  # space-time levy area case
@@ -99,12 +100,13 @@
         dt1, w1, hh1, bhh1 = x1

         w_su = w1 - w0
-        su = jnp.asarray(dt1 - dt0, dtype=w0.dtype)
+        su = jnp.asarray(dt1 - dt0, dtype=complex_to_real_dtype(w0.dtype))
         _su = jnp.where(jnp.abs(su) < jnp.finfo(su).eps, jnp.inf, su)
         inverse_su = 1 / _su
-        u_bb_s = dt1 * w0 - dt0 * w1
-        bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s  # bhh_su = H_{s,u} * (u-s)
-        hh_su = inverse_su * bhh_su
+        with jax.numpy_dtype_promotion("standard"):
+            u_bb_s = dt1 * w0 - dt0 * w1
+            bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s  # bhh_su = H_{s,u} * (u-s)
+            hh_su = inverse_su * bhh_su
         return SpaceTimeLevyArea(dt=su, W=w_su, H=hh_su)
     else:
         assert False
@@ -283,9 +285,10 @@ def _evaluate_leaf(
         tuple[RealScalarLike, Array], tuple[RealScalarLike, Array, Array, Array]
     ]:
         shape, dtype = struct.shape, struct.dtype
+        tdtype = complex_to_real_dtype(dtype)

-        t0 = jnp.zeros((), dtype)
-        r = jnp.asarray(r, dtype)
+        t0 = jnp.zeros((), tdtype)
+        r = jnp.asarray(r, tdtype)

         if self.levy_area is SpaceTimeLevyArea:
             state_key, init_key_w, init_key_la = jr.split(key, 3)
@@ -394,27 +397,31 @@ def _body_fun(_state: _State):
             a = d_prime * sr3 * sr_ru_half
             b = d_prime * ru3 * sr_ru_half

-            w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
-            w_r = w_s + w_sr
-            c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
-            bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
-            bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)
+            with jax.numpy_dtype_promotion("standard"):
+                w_sr = (
+                    sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
+                )
+                w_r = w_s + w_sr
+                c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
+                bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
+                bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)

-            inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r)
-            hh_r = inverse_r * bhh_r
+                inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r)
+                hh_r = inverse_r * bhh_r

         elif self.levy_area is BrownianIncrement:
-            w_mean = w_s + sr / su * w_su
-            if self._spline == "sqrt":
-                z = jr.normal(final_state.key, shape, dtype)
-                bb = jnp.sqrt(sr * ru / su) * z
-            elif self._spline == "quad":
-                z = jr.normal(final_state.key, shape, dtype)
-                bb = (sr * ru / su) * z
-            elif self._spline == "zero":
-                bb = jnp.zeros(shape, dtype)
-            else:
-                assert False
+            with jax.numpy_dtype_promotion("standard"):
+                w_mean = w_s + sr / su * w_su
+                if self._spline == "sqrt":
+                    z = jr.normal(final_state.key, shape, dtype)
+                    bb = jnp.sqrt(sr * ru / su) * z
+                elif self._spline == "quad":
+                    z = jr.normal(final_state.key, shape, dtype)
+                    bb = (sr * ru / su) * z
+                elif self._spline == "zero":
+                    bb = jnp.zeros(shape, dtype)
+                else:
+                    assert False
             w_r = w_mean + bb
             return r, w_r

@@ -497,8 +504,8 @@ def _brownian_arch(

         w_t = w_s + w_st
         w_stu = (w_s, w_t, w_u)
-
-        bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t)
+        with jax.numpy_dtype_promotion("standard"):
+            bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t)
         bhh_stu = (bhh_s, bhh_t, bhh_u)
         bkk_stu = None
         bkk_st_tu = None
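Note (illustrative, not part of the diff): `Float` is widened to `Inexact` so the jaxtyping annotations accept complex as well as floating arrays, and the mixed real/complex arithmetic is wrapped in `jax.numpy_dtype_promotion("standard")`, a context manager that locally restores NumPy-style promotion even if a user runs with strict promotion. A minimal sketch, assuming strict promotion has been enabled globally:

import jax
import jax.numpy as jnp

jax.config.update("jax_numpy_dtype_promotion", "strict")

t = jnp.asarray(0.25, dtype=jnp.float32)  # real-valued time quantity
w = jnp.ones((3,), dtype=jnp.complex64)   # complex-valued Brownian sample

try:
    t * w  # strict mode refuses to promote float32 together with complex64
except Exception as err:
    print(type(err).__name__)

with jax.numpy_dtype_promotion("standard"):
    out = t * w  # promoted to complex64, as in plain NumPy
print(out.dtype)  # complex64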
diffrax/_global_interpolation.py (29 additions, 19 deletions)

@@ -143,7 +143,10 @@ def _index(_ys):
         next_t = self.ts[index + 1]
         diff_t = next_t - prev_t

-        return (prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)).ω
+        with jax.numpy_dtype_promotion("standard"):
+            return (
+                prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)
+            ).ω

     @eqx.filter_jit
     def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]:
@@ -165,10 +168,11 @@ def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]:

         index, _ = self._interpret_t(t, left)

-        return (
-            (ω(self.ys)[index + 1] - ω(self.ys)[index])
-            / (self.ts[index + 1] - self.ts[index])
-        ).ω
+        with jax.numpy_dtype_promotion("standard"):
+            return (
+                (ω(self.ys)[index + 1] - ω(self.ys)[index])
+                / (self.ts[index + 1] - self.ts[index])
+            ).ω


 LinearInterpolation.__init__.__doc__ = """**Arguments:**
@@ -254,10 +258,11 @@ def evaluate(

         d, c, b, a = self.coeffs

-        return (
-            ω(a)[index]
-            + frac * (ω(b)[index] + frac * (ω(c)[index] + frac * ω(d)[index]))
-        ).ω
+        with jax.numpy_dtype_promotion("standard"):
+            return (
+                ω(a)[index]
+                + frac * (ω(b)[index] + frac * (ω(c)[index] + frac * ω(d)[index]))
+            ).ω

     @eqx.filter_jit
     def derivative(
@@ -283,7 +288,8 @@ def derivative(

         d, c, b, _ = self.coeffs

-        return (ω(b)[index] + frac * (2 * ω(c)[index] + frac * 3 * ω(d)[index])).ω
+        with jax.numpy_dtype_promotion("standard"):
+            return (ω(b)[index] + frac * (2 * ω(c)[index] + frac * 3 * ω(d)[index])).ω


 CubicInterpolation.__init__.__doc__ = """**Arguments:**
@@ -622,8 +628,9 @@ def _hermite_forward(
 ]:
     prev_ti, prev_yi, prev_deriv_i = carry
     ti, yi, next_ti, next_yi = value
-    first_deriv_i = (next_yi - yi) / (next_ti - ti)
-    later_deriv_i = (yi - prev_yi) / (ti - prev_ti)
+    with jax.numpy_dtype_promotion("standard"):
+        first_deriv_i = (next_yi - yi) / (next_ti - ti)
+        later_deriv_i = (yi - prev_yi) / (ti - prev_ti)
     deriv_i = jnp.where(jnp.isnan(prev_yi), first_deriv_i, later_deriv_i)
     cond = jnp.isnan(yi)
     carry_ti = jnp.where(cond, prev_ti, ti)
@@ -635,13 +642,15 @@

 def _hermite_coeffs(t0, y0, deriv0, t1, y1):
     t_diff = t1 - t0
-    deriv1 = (y1 - y0) / t_diff
-    d_deriv = deriv1 - deriv0
-
-    a = y0
-    b = deriv0
-    c = 2 * d_deriv / t_diff
-    d = -d_deriv / t_diff**2
+    with jax.numpy_dtype_promotion("standard"):
+        deriv1 = (y1 - y0) / t_diff
+        d_deriv = deriv1 - deriv0
+
+        a = y0
+        b = deriv0
+        c = 2 * d_deriv / t_diff
+        d = -d_deriv / (t_diff**2)

     return d, c, b, a

@@ -684,7 +693,8 @@ def _backward_hermite_coefficients(
     else:
         y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape)
     if deriv0 is None:
-        deriv0 = (next_ys[0] - y0) / (next_ts[0] - t0)
+        with jax.numpy_dtype_promotion("standard"):
+            deriv0 = (next_ys[0] - y0) / (next_ts[0] - t0)
     else:
         deriv0 = jnp.broadcast_to(deriv0, ys[0].shape)
     ts = ts[:-1]
diffrax/_integrate.py (2 additions, 1 deletion)

@@ -155,7 +155,8 @@ def _check(term_cls, term, term_contr_kwargs, yi):
     # If we've got to this point then the term is compatible

     try:
-        jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
+        with jax.numpy_dtype_promotion("standard"):
+            jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
     except ValueError:
         # ValueError may also arise from mismatched tree structures
         return False
diffrax/_local_interpolation.py (16 additions, 9 deletions)

@@ -1,6 +1,7 @@
 from collections.abc import Callable
 from typing import cast, Optional, TYPE_CHECKING

+import jax
 import jax.numpy as jnp
 import jax.tree_util as jtu
 import numpy as np
@@ -35,12 +36,15 @@ def evaluate(
         self, t0: RealScalarLike, t1: Optional[RealScalarLike] = None, left: bool = True
     ) -> PyTree[Array]:
         del left
-        if t1 is None:
-            coeff = linear_rescale(self.t0, t0, self.t1)
-            return (self.y0**ω + coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω
-        else:
-            coeff = (t1 - t0) / (self.t1 - self.t0)
-            return (coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω
+        with jax.numpy_dtype_promotion("standard"):
+            if t1 is None:
+                coeff = linear_rescale(self.t0, t0, self.t1)
+                return (
+                    (self.y0**ω + coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω
+                )
+            else:
+                coeff = (t1 - t0) / (self.t1 - self.t0)
+                return (coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω


 class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation):
@@ -82,7 +86,8 @@ def evaluate(
         t = linear_rescale(self.t0, t0, self.t1)

         def _eval(_coeffs):
-            return jnp.polyval(_coeffs, t)
+            with jax.numpy_dtype_promotion("standard"):
+                return jnp.polyval(_coeffs, t)

         return jtu.tree_map(_eval, self.coeffs)

@@ -104,7 +109,8 @@ def __init__(
         k: PyTree[Shaped[Array, "order ?*y"], "Y"],
     ):
         def _calculate(_y0, _y1, _k):
-            _ymid = _y0 + jnp.tensordot(self.c_mid, _k, axes=1)
+            with jax.numpy_dtype_promotion("standard"):
+                _ymid = _y0 + jnp.tensordot(self.c_mid, _k, axes=1)
             _f0 = _k[0]
             _f1 = _k[-1]
             # TODO: rewrite as matrix-vector product?
@@ -127,6 +133,7 @@ def evaluate(
         t = linear_rescale(self.t0, t0, self.t1)

         def _eval(_coeffs):
-            return jnp.polyval(_coeffs, t)
+            with jax.numpy_dtype_promotion("standard"):
+                return jnp.polyval(_coeffs, t)

         return jtu.tree_map(_eval, self.coeffs)
diffrax/_root_finder/_verychord.py (8 additions, 4 deletions)

@@ -11,6 +11,7 @@
 import optimistix as optx
 from equinox.internal import ω
 from jaxtyping import Array, Bool, PyTree, Scalar
+from lineax.internal import complex_to_real_dtype

 from .._custom_types import Y

@@ -97,11 +98,12 @@ def init(
             y_dtype = lxi.default_floating_dtype()
         else:
             y_dtype = jnp.result_type(*y_leaves)
+        diff_dtype = complex_to_real_dtype(y_dtype)
         init_state = _VeryChordState(
             linear_state=linear_state,
             diff=jtu.tree_map(lambda x: jnp.full(x.shape, jnp.inf, x.dtype), y),
-            diffsize=jnp.array(jnp.inf, dtype=y_dtype),
-            diffsize_prev=jnp.array(1.0, dtype=y_dtype),
+            diffsize=jnp.array(jnp.inf, dtype=diff_dtype),
+            diffsize_prev=jnp.array(1.0, dtype=diff_dtype),
             result=optx.RESULTS.successful,
             step=jnp.array(0),
         )
@@ -127,8 +129,10 @@ def step(
         )
         diff = sol.value
         new_y = (y**ω - diff**ω).ω
-        scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
-        diffsize = self.norm((diff**ω / scale**ω).ω)
+
+        with jax.numpy_dtype_promotion("standard"):
+            scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
+            diffsize = self.norm((diff**ω / scale**ω).ω)
         new_state = _VeryChordState(
             linear_state=state.linear_state,
             diff=diff,
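Note (illustrative, not part of the diff): `diffsize` is a norm of the scaled Newton update, so it stays real even when `y` is complex; storing it in `complex_to_real_dtype(y_dtype)` keeps the solver state consistent. A rough sketch of that invariant, using an RMS norm as a stand-in for `self.norm` (an assumption, not necessarily the root finder's choice):

import jax
import jax.numpy as jnp
from lineax.internal import complex_to_real_dtype

y = jnp.array([1.0 + 2.0j, 3.0 - 1.0j], dtype=jnp.complex64)
diff = 1e-3 * y                               # pretend Newton update
diff_dtype = complex_to_real_dtype(y.dtype)   # float32

atol, rtol = 1e-6, 1e-3
with jax.numpy_dtype_promotion("standard"):
    scale = atol + rtol * jnp.abs(y)                           # real-valued tolerances
    diffsize = jnp.sqrt(jnp.mean(jnp.abs(diff / scale) ** 2))  # real scalar

assert diffsize.dtype == diff_dtype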
diffrax/_solver/dopri8.py (2 additions, 1 deletion)

@@ -298,7 +298,8 @@ def evaluate(
             return self.evaluate(t1) - self.evaluate(t0)
         t = linear_rescale(self.t0, t0, self.t1)
         coeffs = _vmap_polyval(jnp.asarray(self.eval_coeffs, dtype=t.dtype), t) * t
-        return (self.y0**ω + vector_tree_dot(coeffs, self.k) ** ω).ω
+        with jax.numpy_dtype_promotion("standard"):
+            return (self.y0**ω + vector_tree_dot(coeffs, self.k) ** ω).ω


 class Dopri8(AbstractERK):
diffrax/_solver/kencarp3.py (2 additions, 1 deletion)

@@ -117,7 +117,8 @@ def evaluate(
         explicit_k, implicit_k = self.k
         k = (explicit_k**ω + implicit_k**ω).ω
         coeffs = t * jax.vmap(lambda row: jnp.polyval(row, t))(self.coeffs)
-        return (self.y0**ω + vector_tree_dot(coeffs, k) ** ω).ω
+        with jax.numpy_dtype_promotion("standard"):
+            return (self.y0**ω + vector_tree_dot(coeffs, k) ** ω).ω


 class _KenCarp3Interpolation(KenCarpInterpolation):
diffrax/_solver/milstein.py (8 additions, 3 deletions)

@@ -214,7 +214,8 @@ def step(
                 leaf = jnp.tensordot(l1[..., None], l2[None, ...], axes=1)
                 if i1 == i2:
                     eye = jnp.eye(l1.size).reshape(l1.shape + l1.shape)
-                    leaf = leaf - Δt * eye
+                    with jax.numpy_dtype_promotion("standard"):
+                        leaf = leaf - Δt * eye
                 leaves_ΔwΔw.append(leaf)
         tree_ΔwΔw = tree_Δw.compose(tree_Δw)
         ΔwΔw = jtu.tree_unflatten(tree_ΔwΔw, leaves_ΔwΔw)
@@ -236,7 +237,9 @@ def _to_vmap(_g0):
             # _g0 has structure (tree(y0), leaf(y0))
             _, _jvp = jax.jvp(_to_vjp, (y0,), (_g0,))
             # jvp has structure (tree(g0), leaf(g0))
-            _jvp_matrix = jax.jacfwd(lambda _Δw: diffusion.prod(_jvp, _Δw))(Δw)
+            _jvp_matrix = jax.jacfwd(
+                lambda _Δw: diffusion.prod(_jvp, _Δw), holomorphic=jnp.iscomplexobj(Δw)
+            )(Δw)
             # _jvp_matrix has structure (tree(y0), tree(Δw), leaf(y0), leaf(Δw))
             return _jvp_matrix

@@ -282,7 +285,9 @@ def _to_treemap(_Δw, _g0):
         Δw_treedef = jtu.tree_structure(Δw)
         # g0 has structure (tree(g0), leaf(g0))
         # Which we now transform into its isomorphic matrix form, as above.
-        g0_matrix = jax.jacfwd(lambda _Δw: diffusion.prod(g0, _Δw))(Δw)
+        g0_matrix = jax.jacfwd(
+            lambda _Δw: diffusion.prod(g0, _Δw), holomorphic=jnp.iscomplexobj(Δw)
+        )(Δw)
         # g0_matrix has structure (tree(y0), tree(Δw), leaf(y0), leaf(Δw))
         g0_matrix = jtu.tree_transpose(y_treedef, Δw_treedef, g0_matrix)
         # g0_matrix has structure (tree(Δw), tree(y0), leaf(y0), leaf(Δw))
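Note (illustrative, not part of the diff): `jax.jacfwd` rejects complex inputs unless `holomorphic=True` is passed; using `holomorphic=jnp.iscomplexobj(Δw)` enables this only when the noise is complex, which is sound here because the diffusion term is linear (hence holomorphic) in the Brownian increment. A minimal sketch with a hypothetical linear map standing in for `diffusion.prod`:

import jax
import jax.numpy as jnp

g = jnp.array([[1.0 + 1.0j, 0.0], [0.5j, 2.0 - 1.0j]], dtype=jnp.complex64)

def prod(dw):
    # Linear in dw, so it is holomorphic and jacfwd is well-defined.
    return g @ dw

dw = jnp.zeros((2,), dtype=jnp.complex64)
jac = jax.jacfwd(prod, holomorphic=jnp.iscomplexobj(dw))(dw)
print(jnp.allclose(jac, g))  # True: the Jacobian of a linear map is its matrix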