
JAX support #35

@crusaderky


Blockers for JAX support, as of JAX 0.5.2:

Sobol

jnp.bitwise_xor.accumulate is quite slow on CPU and extremely slow on GPU. Alternative, specialized implementations exist; however, because the device cannot be queried inside jax.jit, they can't be written into pyscenarios without resorting to cumbersome parameters or module-level configuration flags.
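
For illustration, one generic alternative to a sequential accumulate is a parallel prefix scan. Because XOR is associative, `jax.lax.associative_scan` can compute the same running XOR. This is a minimal sketch, not pyscenarios code: the helper name `cumxor_scan` is hypothetical, and no claim is made here about its speed relative to `jnp.bitwise_xor.accumulate` on any particular device.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnames="axis")
def cumxor_scan(x, axis=0):
    """Cumulative bitwise XOR along `axis` (hypothetical helper, not pyscenarios API).

    XOR is associative, so the running XOR can be computed with a
    parallel prefix scan instead of a sequential accumulate.
    """
    return jax.lax.associative_scan(jnp.bitwise_xor, x, axis=axis)


# Small uint32 example: each output row is the XOR of all input rows up to it.
x = jnp.array([[1, 2], [3, 4], [5, 6]], dtype=jnp.uint32)
print(cumxor_scan(x, axis=0))  # rows: [1, 2], [2, 6], [7, 0]

# Cross-check against NumPy's reference ufunc accumulate.
ref = np.bitwise_xor.accumulate(np.asarray(x), axis=0)
assert np.array_equal(np.asarray(cumxor_scan(x, axis=0)), ref)
```

`axis` is marked static so that `jax.jit` specializes on it. Note that this sketch does not address the device-selection problem described above: it is a single implementation used on every backend.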

T* copula

Missing distribution functions in jax.scipy:
