JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators such as GPUs and TPUs.
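The snippet below is a minimal sketch of the two ideas in the sentence above. It assumes only that `jax` is installed. It writes a toy loss with the NumPy-like `jax.numpy` API, differentiates it with `jax.grad`, and compiles the gradient function with `jax.jit` (XLA). The loss function and data are illustrative and are not taken from any project in this list.

```python
import jax
import jax.numpy as jnp

# Toy mean-squared-error loss written with the NumPy-like jax.numpy API.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# jax.grad returns a function computing d(loss)/dw (the first argument);
# jax.jit compiles that function with XLA for the available backend (CPU, GPU or TPU).
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)
# Analytically the gradient is (2 / n) * x.T @ (x @ w - y), i.e. [-7., -10.] for this data.
print(g)
```

The first call to `grad_loss` triggers tracing and compilation; later calls with arrays of the same shape and dtype reuse the compiled XLA program.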
This is a curated list of awesome Equinox and JAX projects and other resources. Contributions are welcome!
Official examples can be found in the
Models:
- PaLM-jax - Implementation of the specific Transformer architecture from "PaLM: Scaling Language Modeling with Pathways" in JAX (Equinox framework)
- mistral_jax - A port of the Mistral-7B model to JAX

Projects and Packages:
- levanter - Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX
- eqxvision - A Python package of computer vision models for the Equinox ecosystem
- haliax - Named Tensors for Legible Deep Learning in JAX
- diffrax - Numerical differential equation solvers
- lineax - Linear solvers in JAX and Equinox
- optimistix - Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox
- sympy2jax - Turn SymPy expressions into trainable JAX expressions
- flowMC - Normalizing-flow-enhanced sampling package for probabilistic inference in JAX
- flowjax - Distributions and normalizing flows in JAX
- traceax - Stochastic trace estimation using JAX
- galax - Galactic and Gravitational Dynamics in Python (+ GPU and autodiff)
- coordinax - Coordinates in JAX
- unxt - Unitful Quantities in JAX
- statedict2pytree - Transforming a PyTorch model into a JAX PyTree

Always useful:
- jaxtyping - Type annotations for the shape/dtype of arrays.

Deep learning:
- Optax - First-order gradient (SGD, Adam, ...) optimisers.
- Orbax - Checkpointing (async/multi-host/multi-device).

Scientific computing:
- BlackJAX - Probabilistic + Bayesian sampling.
- PySR - Symbolic regression (non-JAX honourable mention!).

Awesome JAX:
- Awesome JAX - A longer list of other JAX projects.

Contributions welcome! Read the contribution guidelines first.
Repository inspired by:
