@@ -45,7 +45,10 @@ def test_fixed_point(solver, _fn, init, args):
4545
4646@pytest .mark .parametrize ("solver" , _fp_solvers )
4747@pytest .mark .parametrize ("_fn, init, args" , fixed_point_fn_init_args )
48- def test_fixed_point_jvp (getkey , solver , _fn , init , args ):
48+ @pytest .mark .parametrize ("dtype" , [jnp .float64 , jnp .complex128 ])
49+ def test_fixed_point_jvp (getkey , solver , _fn , init , dtype , args ):
50+ args = jtu .tree_map (lambda x : x .astype (dtype ), args )
51+ init = jtu .tree_map (lambda x : x .astype (dtype ), init )
4952 atol = rtol = 1e-3
5053 has_aux = random .choice ([True , False ])
5154 if has_aux :
@@ -54,8 +57,10 @@ def test_fixed_point_jvp(getkey, solver, _fn, init, args):
5457 fn = _fn
5558
5659 dynamic_args , static_args = eqx .partition (args , eqx .is_array )
57- t_init = jtu .tree_map (lambda x : jr .normal (getkey (), x .shape ), init )
58- t_dynamic_args = jtu .tree_map (lambda x : jr .normal (getkey (), x .shape ), dynamic_args )
60+ t_init = jtu .tree_map (lambda x : jr .normal (getkey (), x .shape , dtype = dtype ), init )
61+ t_dynamic_args = jtu .tree_map (
62+ lambda x : jr .normal (getkey (), x .shape , dtype = dtype ), dynamic_args
63+ )
5964
6065 def fixed_point (x , dynamic_args , * , adjoint ):
6166 args = eqx .combine (dynamic_args , static_args )
0 commit comments