1010import jax .random as jr
1111import jax .tree_util as jtu
1212import lineax .internal as lxi
13- from jaxtyping import Array , Float , PRNGKeyArray , PyTree
13+ from jaxtyping import Array , Inexact , PRNGKeyArray , PyTree
14+ from lineax .internal import complex_to_real_dtype
1415
1516from .._custom_types import (
1617 AbstractBrownianIncrement ,
5455# For the midpoint rule for generating space-time Levy area see Theorem 6.1.6.
5556# For the general interpolation rule for space-time Levy area see Theorem 6.1.4.
5657
57- FloatDouble : TypeAlias = tuple [Float [Array , " *shape" ], Float [Array , " *shape" ]]
58+ FloatDouble : TypeAlias = tuple [Inexact [Array , " *shape" ], Inexact [Array , " *shape" ]]
5859FloatTriple : TypeAlias = tuple [
59- Float [Array , " *shape" ], Float [Array , " *shape" ], Float [Array , " *shape" ]
60+ Inexact [Array , " *shape" ], Inexact [Array , " *shape" ], Inexact [Array , " *shape" ]
6061]
6162_Spline : TypeAlias = Literal ["sqrt" , "quad" , "zero" ]
6263_BrownianReturn = TypeVar ("_BrownianReturn" , bound = AbstractBrownianIncrement )
@@ -90,7 +91,7 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
9091 assert len (x1 ) == 2
9192 dt0 , w0 = x0
9293 dt1 , w1 = x1
93- su = jnp .asarray (dt1 - dt0 , dtype = w0 .dtype )
94+ su = jnp .asarray (dt1 - dt0 , dtype = complex_to_real_dtype ( w0 .dtype ) )
9495 return BrownianIncrement (dt = su , W = w1 - w0 )
9596
9697 elif len (x0 ) == 4 : # space-time levy area case
@@ -99,12 +100,13 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
99100 dt1 , w1 , hh1 , bhh1 = x1
100101
101102 w_su = w1 - w0
102- su = jnp .asarray (dt1 - dt0 , dtype = w0 .dtype )
103+ su = jnp .asarray (dt1 - dt0 , dtype = complex_to_real_dtype ( w0 .dtype ) )
103104 _su = jnp .where (jnp .abs (su ) < jnp .finfo (su ).eps , jnp .inf , su )
104105 inverse_su = 1 / _su
105- u_bb_s = dt1 * w0 - dt0 * w1
106- bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
107- hh_su = inverse_su * bhh_su
106+ with jax .numpy_dtype_promotion ("standard" ):
107+ u_bb_s = dt1 * w0 - dt0 * w1
108+ bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
109+ hh_su = inverse_su * bhh_su
108110 return SpaceTimeLevyArea (dt = su , W = w_su , H = hh_su )
109111 else :
110112 assert False
@@ -283,9 +285,10 @@ def _evaluate_leaf(
283285 tuple [RealScalarLike , Array ], tuple [RealScalarLike , Array , Array , Array ]
284286 ]:
285287 shape , dtype = struct .shape , struct .dtype
288+ tdtype = complex_to_real_dtype (dtype )
286289
287- t0 = jnp .zeros ((), dtype )
288- r = jnp .asarray (r , dtype )
290+ t0 = jnp .zeros ((), tdtype )
291+ r = jnp .asarray (r , tdtype )
289292
290293 if self .levy_area is SpaceTimeLevyArea :
291294 state_key , init_key_w , init_key_la = jr .split (key , 3 )
@@ -394,27 +397,31 @@ def _body_fun(_state: _State):
394397 a = d_prime * sr3 * sr_ru_half
395398 b = d_prime * ru3 * sr_ru_half
396399
397- w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b ) / su * x1
398- w_r = w_s + w_sr
399- c = jnp .sqrt (3 * sr3 * ru3 ) / (6 * d )
400- bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
401- bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r )
400+ with jax .numpy_dtype_promotion ("standard" ):
401+ w_sr = (
402+ sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b ) / su * x1
403+ )
404+ w_r = w_s + w_sr
405+ c = jnp .sqrt (3 * sr3 * ru3 ) / (6 * d )
406+ bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
407+ bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r )
402408
403- inverse_r = 1 / jnp .where (jnp .abs (r ) < jnp .finfo (r ).eps , jnp .inf , r )
404- hh_r = inverse_r * bhh_r
409+ inverse_r = 1 / jnp .where (jnp .abs (r ) < jnp .finfo (r ).eps , jnp .inf , r )
410+ hh_r = inverse_r * bhh_r
405411
406412 elif self .levy_area is BrownianIncrement :
407- w_mean = w_s + sr / su * w_su
408- if self ._spline == "sqrt" :
409- z = jr .normal (final_state .key , shape , dtype )
410- bb = jnp .sqrt (sr * ru / su ) * z
411- elif self ._spline == "quad" :
412- z = jr .normal (final_state .key , shape , dtype )
413- bb = (sr * ru / su ) * z
414- elif self ._spline == "zero" :
415- bb = jnp .zeros (shape , dtype )
416- else :
417- assert False
413+ with jax .numpy_dtype_promotion ("standard" ):
414+ w_mean = w_s + sr / su * w_su
415+ if self ._spline == "sqrt" :
416+ z = jr .normal (final_state .key , shape , dtype )
417+ bb = jnp .sqrt (sr * ru / su ) * z
418+ elif self ._spline == "quad" :
419+ z = jr .normal (final_state .key , shape , dtype )
420+ bb = (sr * ru / su ) * z
421+ elif self ._spline == "zero" :
422+ bb = jnp .zeros (shape , dtype )
423+ else :
424+ assert False
418425 w_r = w_mean + bb
419426 return r , w_r
420427
@@ -497,8 +504,8 @@ def _brownian_arch(
497504
498505 w_t = w_s + w_st
499506 w_stu = (w_s , w_t , w_u )
500-
501- bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t )
507+ with jax . numpy_dtype_promotion ( "standard" ):
508+ bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t )
502509 bhh_stu = (bhh_s , bhh_t , bhh_u )
503510 bkk_stu = None
504511 bkk_st_tu = None
0 commit comments