Skip to content

Commit ced7b6b

Browse files
committed
New sde related fixes
1 parent 433cc1b commit ced7b6b

File tree

7 files changed

+64
-54
lines changed

7 files changed

+64
-54
lines changed

diffrax/_brownian/tree.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
9191
assert len(x1) == 2
9292
dt0, w0 = x0
9393
dt1, w1 = x1
94-
su = jnp.asarray(dt1 - dt0, dtype=w0.dtype)
94+
su = jnp.asarray(dt1 - dt0, dtype=complex_to_real_dtype(w0.dtype))
9595
return BrownianIncrement(dt=su, W=w1 - w0)
9696

9797
elif len(x0) == 4: # space-time levy area case
@@ -100,12 +100,13 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
100100
dt1, w1, hh1, bhh1 = x1
101101

102102
w_su = w1 - w0
103-
su = jnp.asarray(dt1 - dt0, dtype=w0.dtype)
103+
su = jnp.asarray(dt1 - dt0, dtype=complex_to_real_dtype(w0.dtype))
104104
_su = jnp.where(jnp.abs(su) < jnp.finfo(su).eps, jnp.inf, su)
105105
inverse_su = 1 / _su
106-
u_bb_s = dt1 * w0 - dt0 * w1
107-
bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
108-
hh_su = inverse_su * bhh_su
106+
with jax.numpy_dtype_promotion("standard"):
107+
u_bb_s = dt1 * w0 - dt0 * w1
108+
bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
109+
hh_su = inverse_su * bhh_su
109110
return SpaceTimeLevyArea(dt=su, W=w_su, H=hh_su)
110111
else:
111112
assert False
@@ -396,27 +397,31 @@ def _body_fun(_state: _State):
396397
a = d_prime * sr3 * sr_ru_half
397398
b = d_prime * ru3 * sr_ru_half
398399

399-
w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
400-
w_r = w_s + w_sr
401-
c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
402-
bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
403-
bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)
400+
with jax.numpy_dtype_promotion("standard"):
401+
w_sr = (
402+
sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
403+
)
404+
w_r = w_s + w_sr
405+
c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
406+
bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
407+
bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)
404408

405-
inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r)
406-
hh_r = inverse_r * bhh_r
409+
inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r)
410+
hh_r = inverse_r * bhh_r
407411

408412
elif self.levy_area is BrownianIncrement:
409-
w_mean = w_s + sr / su * w_su
410-
if self._spline == "sqrt":
411-
z = jr.normal(final_state.key, shape, dtype)
412-
bb = jnp.sqrt(sr * ru / su) * z
413-
elif self._spline == "quad":
414-
z = jr.normal(final_state.key, shape, dtype)
415-
bb = (sr * ru / su) * z
416-
elif self._spline == "zero":
417-
bb = jnp.zeros(shape, dtype)
418-
else:
419-
assert False
413+
with jax.numpy_dtype_promotion("standard"):
414+
w_mean = w_s + sr / su * w_su
415+
if self._spline == "sqrt":
416+
z = jr.normal(final_state.key, shape, dtype)
417+
bb = jnp.sqrt(sr * ru / su) * z
418+
elif self._spline == "quad":
419+
z = jr.normal(final_state.key, shape, dtype)
420+
bb = (sr * ru / su) * z
421+
elif self._spline == "zero":
422+
bb = jnp.zeros(shape, dtype)
423+
else:
424+
assert False
420425
w_r = w_mean + bb
421426
return r, w_r
422427

@@ -499,8 +504,8 @@ def _brownian_arch(
499504

500505
w_t = w_s + w_st
501506
w_stu = (w_s, w_t, w_u)
502-
503-
bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t)
507+
with jax.numpy_dtype_promotion("standard"):
508+
bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t)
504509
bhh_stu = (bhh_s, bhh_t, bhh_u)
505510
bkk_stu = None
506511
bkk_st_tu = None

diffrax/_integrate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ def _check(term_cls, term, term_contr_kwargs, yi):
155155
# If we've got to this point then the term is compatible
156156

157157
try:
158-
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
158+
with jax.numpy_dtype_promotion("standard"):
159+
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
159160
except ValueError:
160161
# ValueError may also arise from mismatched tree structures
161162
return False

diffrax/_solver/implicit_euler.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import ClassVar
33
from typing_extensions import TypeAlias
44

5-
import jax
65
import optimistix as optx
76
from equinox.internal import ω
87

@@ -82,15 +81,14 @@ def step(
8281
# write out a `ButcherTableau` and use `AbstractSDIRK`.
8382
k0 = terms.vf_prod(t0, y0, args, control)
8483
args = (terms.vf_prod, t1, y0, args, control)
85-
with jax.numpy_dtype_promotion("standard"):
86-
nonlinear_sol = optx.root_find(
87-
_implicit_relation,
88-
self.root_finder,
89-
k0,
90-
args,
91-
throw=False,
92-
max_steps=self.root_find_max_steps,
93-
)
84+
nonlinear_sol = optx.root_find(
85+
_implicit_relation,
86+
self.root_finder,
87+
k0,
88+
args,
89+
throw=False,
90+
max_steps=self.root_find_max_steps,
91+
)
9492
k1 = nonlinear_sol.value
9593
y1 = (y0**ω + k1**ω).ω
9694
# Use the trapezoidal rule for adaptive step sizing.

diffrax/_solver/srk.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import numpy as np
1313
from equinox.internal import ω
1414
from jaxtyping import Array, Float, PyTree
15+
from lineax.internal import complex_to_real_dtype
1516

1617
from .._custom_types import (
1718
AbstractBrownianIncrement,
@@ -340,7 +341,9 @@ def step(
340341

341342
# First the drift related stuff
342343
a = self._embed_a_lower(self.tableau.a, dtype)
343-
c = jnp.asarray(np.insert(self.tableau.c, 0, 0.0), dtype=dtype)
344+
c = jnp.asarray(
345+
np.insert(self.tableau.c, 0, 0.0), dtype=complex_to_real_dtype(dtype)
346+
)
344347
b_sol = jnp.asarray(self.tableau.b_sol, dtype=dtype)
345348

346349
def make_zeros():

test/helpers.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
)
1919
from jax import Array
2020
from jaxtyping import PRNGKeyArray, PyTree, Shaped
21-
from lineax.internal import complex_to_real_dtype
2221

2322

2423
all_ode_solvers = (
@@ -252,7 +251,6 @@ def sde_solver_strong_order(
252251
bm_tol,
253252
saveat,
254253
)
255-
dts = 2.0 ** jnp.arange(-3, -3 - num_levels, -1, dtype=dtype)
256254

257255
errs_list, steps_list = [], []
258256
for level in range(level_coarse, level_fine + 1):
@@ -277,7 +275,8 @@ def sde_solver_strong_order(
277275
steps_list.append(jnp.average(steps))
278276
errs_arr = jnp.array(errs_list)
279277
steps_arr = jnp.array(steps_list)
280-
order, _ = jnp.polyfit(jnp.log(1 / steps_arr), jnp.log(errs_arr), 1)
278+
with jax.numpy_dtype_promotion("standard"):
279+
order, _ = jnp.polyfit(jnp.log(1 / steps_arr), jnp.log(errs_arr), 1)
281280
return steps_arr, errs_arr, order
282281

283282

@@ -360,12 +359,14 @@ def _squareplus(x):
360359

361360
def drift(t, y, args):
362361
mlp, _, _ = args
363-
return 0.25 * mlp(y)
362+
with jax.numpy_dtype_promotion("standard"):
363+
return 0.25 * mlp(y)
364364

365365

366366
def diffusion(t, y, args):
367367
_, mlp, noise_dim = args
368-
return 1.0 * mlp(y).reshape(3, noise_dim)
368+
with jax.numpy_dtype_promotion("standard"):
369+
return 1.0 * mlp(y).reshape(3, noise_dim)
369370

370371

371372
def get_mlp_sde(t0, t1, dtype, key, noise_dim):
@@ -447,8 +448,9 @@ def ft(t):
447448
drift_mlp = init_linear_weight(drift_mlp, lap_init, driftkey)
448449

449450
def _drift(t, y, _):
450-
mlp_out = drift_mlp(jnp.concatenate([y, ft(t)]))
451-
return (0.01 * mlp_out - 0.5 * y**3) / (jnp.sum(y**2) + 1)
451+
with jax.numpy_dtype_promotion("standard"):
452+
mlp_out = drift_mlp(jnp.concatenate([y, ft(t)]))
453+
return (0.01 * mlp_out - 0.5 * y**3) / (jnp.sum(y**2) + 1)
452454

453455
diffusion_mx = jr.normal(diffusionkey, (4, y_dim, noise_dim), dtype=dtype)
454456

test/test_sde.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,12 @@ def _solvers_and_orders():
4343
# converges to its own limit (i.e. using itself as reference), and then in a
4444
# different test check whether that limit is the same as the Euler/Heun limit.
4545
@pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders())
46+
@pytest.mark.parametrize(
47+
"dtype",
48+
(jnp.float64,),
49+
)
4650
def test_sde_strong_order_new(
47-
solver_ctr, noise: Literal["any", "com", "add"], theoretical_order
51+
solver_ctr, noise: Literal["any", "com", "add"], theoretical_order, dtype
4852
):
4953
bmkey = jr.PRNGKey(5678)
5054
sde_key = jr.PRNGKey(11)
@@ -54,15 +58,15 @@ def test_sde_strong_order_new(
5458
t1 = 5.3
5559

5660
if noise == "add":
57-
sde = get_time_sde(t0, t1, jnp.float64, sde_key, noise_dim=7)
61+
sde = get_time_sde(t0, t1, dtype, sde_key, noise_dim=7)
5862
else:
5963
if noise == "com":
6064
noise_dim = 1
6165
elif noise == "any":
6266
noise_dim = 5
6367
else:
6468
assert False
65-
sde = get_mlp_sde(t0, t1, jnp.float64, sde_key, noise_dim=noise_dim)
69+
sde = get_mlp_sde(t0, t1, dtype, sde_key, noise_dim=noise_dim)
6670

6771
ref_solver = solver_ctr()
6872
level_coarse, level_fine = 1, 7

test/test_solver.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def order(self, terms):
274274

275275

276276
# Essentially used as a check that our general IMEX implementation is correct.
277-
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
277+
@pytest.mark.parametrize("dtype", (jnp.float64,))
278278
def test_sil3(dtype):
279279
class ReferenceSil3(diffrax.AbstractImplicitSolver):
280280
term_structure = diffrax.MultiTerm[
@@ -314,8 +314,7 @@ def _second_stage(ya, _):
314314
return ya - (y0 + (1 / 3) * f0 + (1 / 6) * g0 + (1 / 6) * g1)
315315

316316
ta = t0 + (1 / 3) * dt
317-
with jax.numpy_dtype_promotion("standard"):
318-
ya = optx.root_find(_second_stage, self.root_finder, y0).value
317+
ya = optx.root_find(_second_stage, self.root_finder, y0).value
319318
fs.append(ex_vf_prod(ta, ya))
320319
gs.append(im_vf_prod(ta, ya))
321320

@@ -329,8 +328,7 @@ def _third_stage(yb, _):
329328

330329
tb = t0 + (2 / 3) * dt
331330

332-
with jax.numpy_dtype_promotion("standard"):
333-
yb = optx.root_find(_third_stage, self.root_finder, ya).value
331+
yb = optx.root_find(_third_stage, self.root_finder, ya).value
334332
fs.append(ex_vf_prod(tb, yb))
335333
gs.append(im_vf_prod(tb, yb))
336334

@@ -349,8 +347,7 @@ def _fourth_stage(yc, _):
349347
)
350348

351349
tc = t1
352-
with jax.numpy_dtype_promotion("standard"):
353-
yc = optx.root_find(_fourth_stage, self.root_finder, yb).value
350+
yc = optx.root_find(_fourth_stage, self.root_finder, yb).value
354351
fs.append(ex_vf_prod(tc, yc))
355352
gs.append(im_vf_prod(tc, yc))
356353

0 commit comments

Comments
 (0)