@@ -91,7 +91,7 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
9191 assert len (x1 ) == 2
9292 dt0 , w0 = x0
9393 dt1 , w1 = x1
94- su = jnp .asarray (dt1 - dt0 , dtype = w0 .dtype )
94+ su = jnp .asarray (dt1 - dt0 , dtype = complex_to_real_dtype ( w0 .dtype ) )
9595 return BrownianIncrement (dt = su , W = w1 - w0 )
9696
9797 elif len (x0 ) == 4 : # space-time levy area case
@@ -100,12 +100,13 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
100100 dt1 , w1 , hh1 , bhh1 = x1
101101
102102 w_su = w1 - w0
103- su = jnp .asarray (dt1 - dt0 , dtype = w0 .dtype )
103+ su = jnp .asarray (dt1 - dt0 , dtype = complex_to_real_dtype ( w0 .dtype ) )
104104 _su = jnp .where (jnp .abs (su ) < jnp .finfo (su ).eps , jnp .inf , su )
105105 inverse_su = 1 / _su
106- u_bb_s = dt1 * w0 - dt0 * w1
107- bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
108- hh_su = inverse_su * bhh_su
106+ with jax .numpy_dtype_promotion ("standard" ):
107+ u_bb_s = dt1 * w0 - dt0 * w1
108+ bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
109+ hh_su = inverse_su * bhh_su
109110 return SpaceTimeLevyArea (dt = su , W = w_su , H = hh_su )
110111 else :
111112 assert False
@@ -396,27 +397,31 @@ def _body_fun(_state: _State):
396397 a = d_prime * sr3 * sr_ru_half
397398 b = d_prime * ru3 * sr_ru_half
398399
399- w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b ) / su * x1
400- w_r = w_s + w_sr
401- c = jnp .sqrt (3 * sr3 * ru3 ) / (6 * d )
402- bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
403- bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r )
400+ with jax .numpy_dtype_promotion ("standard" ):
401+ w_sr = (
402+ sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b ) / su * x1
403+ )
404+ w_r = w_s + w_sr
405+ c = jnp .sqrt (3 * sr3 * ru3 ) / (6 * d )
406+ bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
407+ bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r )
404408
405- inverse_r = 1 / jnp .where (jnp .abs (r ) < jnp .finfo (r ).eps , jnp .inf , r )
406- hh_r = inverse_r * bhh_r
409+ inverse_r = 1 / jnp .where (jnp .abs (r ) < jnp .finfo (r ).eps , jnp .inf , r )
410+ hh_r = inverse_r * bhh_r
407411
408412 elif self .levy_area is BrownianIncrement :
409- w_mean = w_s + sr / su * w_su
410- if self ._spline == "sqrt" :
411- z = jr .normal (final_state .key , shape , dtype )
412- bb = jnp .sqrt (sr * ru / su ) * z
413- elif self ._spline == "quad" :
414- z = jr .normal (final_state .key , shape , dtype )
415- bb = (sr * ru / su ) * z
416- elif self ._spline == "zero" :
417- bb = jnp .zeros (shape , dtype )
418- else :
419- assert False
413+ with jax .numpy_dtype_promotion ("standard" ):
414+ w_mean = w_s + sr / su * w_su
415+ if self ._spline == "sqrt" :
416+ z = jr .normal (final_state .key , shape , dtype )
417+ bb = jnp .sqrt (sr * ru / su ) * z
418+ elif self ._spline == "quad" :
419+ z = jr .normal (final_state .key , shape , dtype )
420+ bb = (sr * ru / su ) * z
421+ elif self ._spline == "zero" :
422+ bb = jnp .zeros (shape , dtype )
423+ else :
424+ assert False
420425 w_r = w_mean + bb
421426 return r , w_r
422427
@@ -499,8 +504,8 @@ def _brownian_arch(
499504
500505 w_t = w_s + w_st
501506 w_stu = (w_s , w_t , w_u )
502-
503- bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t )
507+ with jax . numpy_dtype_promotion ( "standard" ):
508+ bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t )
504509 bhh_stu = (bhh_s , bhh_t , bhh_u )
505510 bkk_stu = None
506511 bkk_st_tu = None
0 commit comments