Skip to content

Commit 21be399

Browse files
committed
New sde related fixes
1 parent ced7b6b commit 21be399

File tree

6 files changed

+48
-21
lines changed

6 files changed

+48
-21
lines changed

diffrax/_brownian/path.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import jax.tree_util as jtu
1010
import lineax.internal as lxi
1111
from jaxtyping import Array, PRNGKeyArray, PyTree
12+
from lineax.internal import complex_to_real_dtype
1213

1314
from .._custom_types import (
1415
AbstractBrownianIncrement,
@@ -130,7 +131,7 @@ def _evaluate_leaf(
130131
):
131132
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
132133
w = jr.normal(key, shape.shape, shape.dtype) * w_std
133-
dt = jnp.asarray(t1 - t0, dtype=shape.dtype)
134+
dt = jnp.asarray(t1 - t0, dtype=complex_to_real_dtype(shape.dtype))
134135

135136
if levy_area is SpaceTimeLevyArea:
136137
key, key_hh = jr.split(key, 2)

diffrax/_solver/srk.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def aux_add_levy(w_leaf, *levy_leaves):
406406
def _comp_g(_t):
407407
return diffusion.vf(_t, y0, args)
408408

409-
g0_g1 = _comp_g(jnp.array([t0, t1], dtype=dtype))
409+
g0_g1 = _comp_g(jnp.array([t0, t1], dtype=complex_to_real_dtype(dtype)))
410410
g0 = jtu.tree_map(lambda g_leaf: g_leaf[0], g0_g1)
411411
# g_delta = 0.5 * g1 - g0
412412
g_delta = jtu.tree_map(lambda g_leaf: 0.5 * (g_leaf[1] - g_leaf[0]), g0_g1)
@@ -534,13 +534,15 @@ def compute_and_insert_kf_j(_h_kfs_in):
534534
return (_h_kfs, None, None), None
535535

536536
def compute_and_insert_kg_j(_w_kgs_in, _levylist_kgs_in):
537-
_w_kg_j = diffusion.vf_prod(t0 + c_j * h, z_j, args, w)
537+
with jax.numpy_dtype_promotion("standard"):
538+
_w_kg_j = diffusion.vf_prod(t0 + c_j * h, z_j, args, w)
538539
new_w_kgs = insert_jth_stage(_w_kgs_in, _w_kg_j, j)
539540

540-
_levylist_kg_j = [
541-
diffusion.vf_prod(t0 + c_j * h, z_j, args, levy)
542-
for levy in levy_areas
543-
]
541+
with jax.numpy_dtype_promotion("standard"):
542+
_levylist_kg_j = [
543+
diffusion.vf_prod(t0 + c_j * h, z_j, args, levy)
544+
for levy in levy_areas
545+
]
544546
new_levylist_kgs = insert_jth_stage(_levylist_kgs_in, _levylist_kg_j, j)
545547
return new_w_kgs, new_levylist_kgs
546548

diffrax/_term.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,8 @@ class WeaklyDiagonalControlTerm(_AbstractControlTerm[_VF, _Control]):
361361
"""
362362

363363
def prod(self, vf: _VF, control: _Control) -> Y:
364-
return jtu.tree_map(operator.mul, vf, control)
364+
with jax.numpy_dtype_promotion("standard"):
365+
return jtu.tree_map(operator.mul, vf, control)
365366

366367

367368
class _ControlToODE(eqx.Module):
@@ -461,7 +462,8 @@ def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _Control:
461462
return (self.direction * self.term.contr(_t0, _t1, **kwargs) ** ω).ω
462463

463464
def prod(self, vf: _VF, control: _Control) -> Y:
464-
return self.term.prod(vf, control)
465+
with jax.numpy_dtype_promotion("standard"):
466+
return self.term.prod(vf, control)
465467

466468
def vf_prod(self, t: RealScalarLike, y: Y, args: Args, control: _Control) -> Y:
467469
t = t * self.direction

test/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def path_l2_dist(
9999
# all but the first two axes (which represent the number of samples
100100
# and the length of saveat). Also sum all the PyTree leaves.
101101
def sum_square_diff(y1, y2):
102-
square_diff = jnp.square(y1 - y2)
102+
with jax.numpy_dtype_promotion("standard"):
103+
square_diff = jnp.square(y1 - y2)
103104
# sum all but the first two axes
104105
axes = range(2, square_diff.ndim)
105106
out = jnp.sum(square_diff, axis=axes)

test/test_brownian.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ def test_shape_and_dtype(ctr, levy_area, use_levy, getkey):
4646
(2,),
4747
(3, 4),
4848
(1, 2, 3, 4),
49+
(1, 2, 3, 4),
50+
{
51+
"a": (1,),
52+
"b": (2, 3),
53+
},
4954
{
5055
"a": (1,),
5156
"b": (2, 3),
@@ -66,7 +71,9 @@ def test_shape_and_dtype(ctr, levy_area, use_levy, getkey):
6671
jnp.float16,
6772
jnp.float32,
6873
jnp.float64,
74+
jnp.complex128,
6975
{"a": None, "b": jnp.float64},
76+
{"a": jnp.float64, "b": jnp.complex128},
7077
(jnp.float16, (jnp.float32, jnp.float64)),
7178
)
7279

test/test_sde.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,12 @@ def get_dt_and_controller(level):
116116
# using a single reference solution. We use Euler if the solver is Ito
117117
# and Heun if the solver is Stratonovich.
118118
@pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders())
119+
@pytest.mark.parametrize(
120+
"dtype",
121+
(jnp.float64,),
122+
)
119123
def test_sde_strong_limit(
120-
solver_ctr, noise: Literal["any", "com", "add"], theoretical_order
124+
solver_ctr, noise: Literal["any", "com", "add"], theoretical_order, dtype
121125
):
122126
bmkey = jr.PRNGKey(5678)
123127
sde_key = jr.PRNGKey(11)
@@ -127,7 +131,7 @@ def test_sde_strong_limit(
127131
t1 = 5.3
128132

129133
if noise == "add":
130-
sde = get_time_sde(t0, t1, jnp.float64, sde_key, noise_dim=3)
134+
sde = get_time_sde(t0, t1, dtype, sde_key, noise_dim=3)
131135
level_fine = 12
132136
if theoretical_order <= 1.0:
133137
level_coarse = 11
@@ -141,7 +145,7 @@ def test_sde_strong_limit(
141145
noise_dim = 5
142146
else:
143147
assert False
144-
sde = get_mlp_sde(t0, t1, jnp.float64, sde_key, noise_dim=noise_dim)
148+
sde = get_mlp_sde(t0, t1, dtype, sde_key, noise_dim=noise_dim)
145149

146150
# Reference solver is always an ODE-viable solver, so its implementation has been
147151
# verified by the ODE tests like test_ode_order.
@@ -210,9 +214,12 @@ def get_matrix(y_leaf):
210214

211215
@pytest.mark.parametrize("shape", [(), (5, 2)])
212216
@pytest.mark.parametrize("solver_ctr", _solvers())
213-
def test_sde_solver_shape(shape, solver_ctr):
217+
@pytest.mark.parametrize(
218+
"dtype",
219+
(jnp.float64, jnp.complex128),
220+
)
221+
def test_sde_solver_shape(shape, solver_ctr, dtype):
214222
pytree = ({"a": 0, "b": [0, 0]}, 0, 0)
215-
dtype = jnp.float64
216223
key = jr.PRNGKey(0)
217224
y0 = jtu.tree_map(lambda _: jr.normal(key, shape, dtype=dtype), pytree)
218225
t0, t1, dt0 = 0.0, 1.0, 0.3
@@ -236,8 +243,7 @@ def test_sde_solver_shape(shape, solver_ctr):
236243
assert leaf[0].shape == shape
237244

238245

239-
def _weakly_diagonal_noise_helper(solver):
240-
dtype = jnp.float64
246+
def _weakly_diagonal_noise_helper(solver, dtype):
241247
w_shape = (3,)
242248
args = (0.5, 1.2)
243249

@@ -265,9 +271,17 @@ def _drift(t, y, args):
265271

266272

267273
@pytest.mark.parametrize("solver_ctr", _solvers())
268-
def test_weakly_diagonal_noise(solver_ctr):
269-
_weakly_diagonal_noise_helper(solver_ctr())
274+
@pytest.mark.parametrize(
275+
"dtype",
276+
(jnp.float64, jnp.complex128),
277+
)
278+
def test_weakly_diagonal_noise(solver_ctr, dtype):
279+
_weakly_diagonal_noise_helper(solver_ctr(), dtype)
270280

271281

272-
def test_halfsolver_term_compatible():
273-
_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()))
282+
@pytest.mark.parametrize(
283+
"dtype",
284+
(jnp.float64, jnp.complex128),
285+
)
286+
def test_halfsolver_term_compatible(dtype):
287+
_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype)

0 commit comments

Comments
 (0)