@@ -116,8 +116,12 @@ def get_dt_and_controller(level):
116116# using a single reference solution. We use Euler if the solver is Ito
117117# and Heun if the solver is Stratonovich.
118118@pytest .mark .parametrize ("solver_ctr,noise,theoretical_order" , _solvers_and_orders ())
119+ @pytest .mark .parametrize (
120+ "dtype" ,
121+ (jnp .float64 ,),
122+ )
119123def test_sde_strong_limit (
120- solver_ctr , noise : Literal ["any" , "com" , "add" ], theoretical_order
124+ solver_ctr , noise : Literal ["any" , "com" , "add" ], theoretical_order , dtype
121125):
122126 bmkey = jr .PRNGKey (5678 )
123127 sde_key = jr .PRNGKey (11 )
@@ -127,7 +131,7 @@ def test_sde_strong_limit(
127131 t1 = 5.3
128132
129133 if noise == "add" :
130- sde = get_time_sde (t0 , t1 , jnp . float64 , sde_key , noise_dim = 3 )
134+ sde = get_time_sde (t0 , t1 , dtype , sde_key , noise_dim = 3 )
131135 level_fine = 12
132136 if theoretical_order <= 1.0 :
133137 level_coarse = 11
@@ -141,7 +145,7 @@ def test_sde_strong_limit(
141145 noise_dim = 5
142146 else :
143147 assert False
144- sde = get_mlp_sde (t0 , t1 , jnp . float64 , sde_key , noise_dim = noise_dim )
148+ sde = get_mlp_sde (t0 , t1 , dtype , sde_key , noise_dim = noise_dim )
145149
146150 # Reference solver is always an ODE-viable solver, so its implementation has been
147151 # verified by the ODE tests like test_ode_order.
@@ -210,9 +214,12 @@ def get_matrix(y_leaf):
210214
211215@pytest .mark .parametrize ("shape" , [(), (5 , 2 )])
212216@pytest .mark .parametrize ("solver_ctr" , _solvers ())
213- def test_sde_solver_shape (shape , solver_ctr ):
217+ @pytest .mark .parametrize (
218+ "dtype" ,
219+ (jnp .float64 , jnp .complex128 ),
220+ )
221+ def test_sde_solver_shape (shape , solver_ctr , dtype ):
214222 pytree = ({"a" : 0 , "b" : [0 , 0 ]}, 0 , 0 )
215- dtype = jnp .float64
216223 key = jr .PRNGKey (0 )
217224 y0 = jtu .tree_map (lambda _ : jr .normal (key , shape , dtype = dtype ), pytree )
218225 t0 , t1 , dt0 = 0.0 , 1.0 , 0.3
@@ -236,8 +243,7 @@ def test_sde_solver_shape(shape, solver_ctr):
236243 assert leaf [0 ].shape == shape
237244
238245
239- def _weakly_diagonal_noise_helper (solver ):
240- dtype = jnp .float64
246+ def _weakly_diagonal_noise_helper (solver , dtype ):
241247 w_shape = (3 ,)
242248 args = (0.5 , 1.2 )
243249
@@ -265,9 +271,17 @@ def _drift(t, y, args):
265271
266272
267273@pytest .mark .parametrize ("solver_ctr" , _solvers ())
268- def test_weakly_diagonal_noise (solver_ctr ):
269- _weakly_diagonal_noise_helper (solver_ctr ())
274+ @pytest .mark .parametrize (
275+ "dtype" ,
276+ (jnp .float64 , jnp .complex128 ),
277+ )
278+ def test_weakly_diagonal_noise (solver_ctr , dtype ):
279+ _weakly_diagonal_noise_helper (solver_ctr (), dtype )
270280
271281
272- def test_halfsolver_term_compatible ():
273- _weakly_diagonal_noise_helper (diffrax .HalfSolver (diffrax .SPaRK ()))
282+ @pytest .mark .parametrize (
283+ "dtype" ,
284+ (jnp .float64 , jnp .complex128 ),
285+ )
286+ def test_halfsolver_term_compatible (dtype ):
287+ _weakly_diagonal_noise_helper (diffrax .HalfSolver (diffrax .SPaRK ()), dtype )
0 commit comments