Blockers for JAX support, as of JAX 0.5.2:
Sobol
`jnp.bitwise_xor.accumulate` is quite slow on CPU and extremely slow on GPU. Alternative, specialized implementations exist. However, because the device cannot be queried inside `jax.jit`, there is no way to wire them into pyscenarios without resorting to cumbersome parameters or module-level config flags.
- `accumulate` poor performance: jax-ml/jax#28097
- Missing `.device` attribute inside `@jax.jit`: jax-ml/jax#26000 (comment)
T* copula
Missing distribution functions in `jax.scipy`:
- `jax.scipy.stats.chi2.ppf` (one could use `jax.random.chisquare` on JAX's own PRNG, but not on Sobol)
- `jax.scipy.stats.t.cdf` (read below)
- `jax.scipy.stats.norm.ppf`
- `jax.scipy.special.betainc` is extremely slow: jax-ml/jax#28547 (`special.betainc`, `stdtr`, and `stdtrit` are very slow)
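For context on why `t.cdf` and `betainc` performance are linked: for `df` degrees of freedom, the two-sided tail of Student's t is `P(|T| > |x|) = I_{df/(df + x**2)}(df/2, 1/2)`, where `I` is the regularized incomplete beta function. A minimal sketch of a t CDF built on `jax.scipy.special.betainc`, assuming float32 inputs. `t_cdf` is a hypothetical helper, not a pyscenarios or JAX API; it is not tuned for accuracy in the far tails and inherits the slowness of `betainc` reported above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import betainc


@jax.jit
def t_cdf(x, df):
    """Student's t CDF via the regularized incomplete beta function.

    For x >= 0: F(x) = 1 - I_{df/(df + x**2)}(df/2, 1/2) / 2, and F(-x) = 1 - F(x).
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    df = jnp.asarray(df, dtype=jnp.float32)
    # Two-sided tail mass P(|T| > |x|), expressed through betainc.
    tail = betainc(df / 2.0, 0.5, df / (df + x**2))
    # Upper half for x > 0, lower half for x < 0; exactly 0.5 at x == 0.
    return 0.5 + jnp.sign(x) * (0.5 - 0.5 * tail)
```

With `df = 1` the t distribution is the standard Cauchy distribution, whose CDF `0.5 + arctan(x) / pi` gives a convenient closed-form check.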