Important: Version 0.2.0 includes a pure JAX backend that no longer requires MPI. For multi-node runs, MPI and NCCL backends are still available through cuDecomp.
JAX reimplementation and bindings for NVIDIA's cuDecomp library (Romero et al. 2022), enabling multi-node parallel FFTs and halo exchanges, implemented directly in low-level NCCL/CUDA-aware MPI, to be called from your JAX code.
Important: Starting with version 0.2.8, jaxDecomp supports JAX's Shardy partitioner, which can be activated via jax.config.update('jax_use_shardy_partitioner', True). This partitioner is enabled by default in JAX 0.7.x and later. Shardy support is an internal implementation change: users should not expect any behavioral differences beyond what the JAX sharding mechanism itself provides, as explained in the JAX Shardy migration documentation.
Below is a simple code snippet illustrating how to perform a 3D FFT on a distributed 3D array, followed by a halo exchange. For demonstration purposes, we force 8 CPU devices via environment variables, which must be set before JAX is imported:
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORM_NAME"] = "cpu"
import jax
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import jaxdecomp
# Create a 2x4 mesh of devices on CPU
pdims = (2, 4)
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
# Create a random 3D array and enforce sharding
a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024, 1024))
a = jax.lax.with_sharding_constraint(a, sharding)
# Parallel FFTs
k_array = jaxdecomp.fft.pfft3d(a)
rec_array = jaxdecomp.fft.pifft3d(k_array)  # inverse FFT of the spectrum back to real space
# Parallel halo exchange
exchanged = jaxdecomp.halo_exchange(a, halo_extents=(16, 16), halo_periods=(True, True))
All these functions are JIT-compatible and support automatic differentiation (with some caveats).
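To make the decomposition concrete: with pdims = (2, 4) and PartitionSpec('x', 'y'), JAX splits axis 0 across 2 devices and axis 1 across 4, so each device holds a 512 x 256 x 1024 block of the 1024 x 1024 x 1024 array. The single-process sketch below computes that block shape and emulates what a conventional periodic halo exchange provides: each block padded with copies of its neighbours' edge cells, wrapping around at the domain boundary. It is a conceptual model built on jax.numpy.pad, it does not call jaxDecomp, and the helpers local_block_shape and periodic_halo_blocks are hypothetical. jaxDecomp's own halo_exchange may lay out or update the halo regions differently.

```python
import jax.numpy as jnp


def local_block_shape(global_shape, pdims):
    """Per-device block shape for a 3D array sharded with PartitionSpec('x', 'y').

    Axis 0 is split over pdims[0] devices, axis 1 over pdims[1] devices,
    and axis 2 stays whole on every device.
    """
    nx, ny, nz = global_shape
    px, py = pdims
    if nx % px or ny % py:
        raise ValueError("each sharded axis must be divisible by its pdims entry")
    return (nx // px, ny // py, nz)


def periodic_halo_blocks(global_array, pdims, halo_extents):
    """Single-process emulation of a periodic halo exchange (hypothetical helper).

    Returns {(i, j): block}, where block (i, j) is one device's local data
    padded on axes 0 and 1 with halo_extents cells taken from its neighbours,
    wrapping around at the global boundaries (periodic in both directions).
    """
    hx, hy = halo_extents
    bx, by, _ = local_block_shape(global_array.shape, pdims)
    # Wrap-padding the global array once gives every block its periodic neighbours.
    padded = jnp.pad(global_array, ((hx, hx), (hy, hy), (0, 0)), mode="wrap")
    blocks = {}
    for i in range(pdims[0]):
        for j in range(pdims[1]):
            x0, y0 = i * bx, j * by  # block origin in the unpadded global array
            blocks[(i, j)] = padded[x0 : x0 + bx + 2 * hx, y0 : y0 + by + 2 * hy, :]
    return blocks


print(local_block_shape((1024, 1024, 1024), (2, 4)))  # (512, 256, 1024)
```

In this model, halo_extents=(16, 16) as in the snippet would grow each 512 x 256 x 1024 block to 544 x 288 x 1024, which is a useful way to estimate the extra memory a halo costs per device.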
Important: Multi-node FFTs work with both the JAX and cuDecomp backends. For CPU with JAX, multi-node runs are supported starting with JAX v0.5.1 (with the gloo backend).
On HPC clusters (e.g., Jean Zay, Perlmutter), you typically launch your script with:
srun python your_script.py
or
mpirun -n 8 python your_script.py
See the Slurm README and template script for more details.
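The launch commands above start one Python process per task, and each process has to join JAX's multi-process runtime before it touches any device. A minimal sketch, assuming the job is started by srun or Open MPI's mpirun (detected here through the SLURM_NTASKS and OMPI_COMM_WORLD_SIZE variables, which is our assumption rather than a jaxDecomp mechanism) and that jax.distributed.initialize() can infer the coordinator from that environment. The helper maybe_initialize_distributed and its use_gloo_on_cpu switch are hypothetical; the gloo option mirrors the CPU note above and relies on JAX's jax_cpu_collectives_implementation flag, so check that it exists in your JAX version.

```python
import os

import jax


def maybe_initialize_distributed(use_gloo_on_cpu=False):
    """Join JAX's multi-process runtime when a launcher started several tasks.

    Hypothetical helper. Assumption: a multi-task launch is detected via
    SLURM_NTASKS (srun) or OMPI_COMM_WORLD_SIZE (Open MPI mpirun).
    Call it before any other JAX operation. Returns True if it initialized.
    """
    ntasks = int(
        os.environ.get("SLURM_NTASKS")
        or os.environ.get("OMPI_COMM_WORLD_SIZE")
        or 1
    )
    if ntasks <= 1:
        return False  # single process: nothing to initialize
    if use_gloo_on_cpu:
        # Cross-process CPU collectives need an implementation such as gloo.
        jax.config.update("jax_cpu_collectives_implementation", "gloo")
    # With no arguments, JAX tries to infer the coordinator address,
    # process count and process id from the cluster environment.
    jax.distributed.initialize()
    return True


initialized = maybe_initialize_distributed()
print(
    f"process {jax.process_index()} of {jax.process_count()}: "
    f"{jax.local_device_count()} local / {jax.device_count()} global devices"
)
```

Running the script directly with python (no launcher) simply stays single-process, so the same file can be tested locally before submitting a job.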
For other features, compile and install jaxDecomp with cuDecomp enabled, as described in the installation instructions:
import jaxdecomp
# Optionally select communication backends (defaults to NCCL)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)
# Then specify backend='cudecomp' in your FFT or halo calls
# (global_array and padded_array stand for sharded JAX arrays such as `a` above):
karray = jaxdecomp.fft.pfft3d(global_array, backend='cudecomp')
recarray = jaxdecomp.fft.pifft3d(karray, backend='cudecomp')
exchanged_array = jaxdecomp.halo_exchange(
    padded_array, halo_extents=(16, 16), halo_periods=(True, True), backend='cudecomp'
)
jaxDecomp is on PyPI:
- Install the appropriate JAX wheel:
  - GPU:
    pip install --upgrade "jax[cuda]"
  - CPU:
    pip install --upgrade "jax[cpu]"
- Install jaxdecomp:
  pip install jaxdecomp
This setup uses the pure-JAX backend; no MPI is required.
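After installing, a quick sanity check shows which JAX backend is active and whether jaxdecomp is importable, without importing it. This is a minimal sketch: describe_install is a hypothetical helper, and it only inspects the local process, so it says nothing about multi-node communication.

```python
import importlib.util

import jax


def describe_install():
    """Summarize the local JAX setup (hypothetical helper, not part of jaxdecomp)."""
    return {
        "jax_version": jax.__version__,
        "backend": jax.default_backend(),  # e.g. 'cpu' or 'gpu'
        "device_count": jax.device_count(),
        # find_spec checks importability without actually importing jaxdecomp
        "jaxdecomp_installed": importlib.util.find_spec("jaxdecomp") is not None,
    }


print(describe_install())
```

If backend reports 'cpu' on a GPU node, the CPU-only JAX wheel was most likely installed instead of the CUDA one.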
If you need to use MPI instead of NCCL on GPU, you can build from GitHub with cuDecomp enabled. This requires the NVIDIA HPC SDK. Before building, make sure that nvc, nvc++, and nvcc are on your PATH, that the CUDA, MPI, and NCCL shared libraries are on your LD_LIBRARY_PATH, and that CC=nvc and CXX=nvc++ are set.
pip install -U pip
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON
Alternatively, clone the repository and install from your local checkout:
git clone https://github.com/DifferentiableUniverseInitiative/jaxDecomp.git --recursive
cd jaxDecomp
pip install -U pip
pip install . -Ccmake.define.JD_CUDECOMP_BACKEND=ON
- If CMake cannot find NVHPC, set:
  export CMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH:$NVCOMPILERS/$NVARCH/22.9/cmake
  and then install again.
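The build prerequisites listed above can be checked before invoking pip. The sketch below uses a hypothetical cudecomp_build_preflight helper that looks only at PATH and the CC, CXX and LD_LIBRARY_PATH variables; it cannot tell whether the CUDA, MPI and NCCL libraries those variables point to are compatible with each other.

```python
import os
import shutil

REQUIRED_TOOLS = ("nvc", "nvc++", "nvcc")
REQUIRED_VARS = {"CC": "nvc", "CXX": "nvc++"}


def cudecomp_build_preflight(env=None):
    """Return a list of problems found in `env` (defaults to os.environ).

    Hypothetical helper: it checks only PATH and environment variables,
    not whether the CUDA/MPI/NCCL libraries are mutually compatible.
    """
    env = os.environ if env is None else env
    path = env.get("PATH", "")
    problems = []
    for tool in REQUIRED_TOOLS:
        # shutil.which searches only the given PATH string
        if shutil.which(tool, path=path) is None:
            problems.append(f"{tool} not found on PATH")
    for var, expected in REQUIRED_VARS.items():
        if env.get(var) != expected:
            problems.append(f"{var} should be set to {expected}")
    if not env.get("LD_LIBRARY_PATH"):
        problems.append("LD_LIBRARY_PATH is not set")
    return problems


for problem in cudecomp_build_preflight():
    print("warning:", problem)
```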
IDRIS Jean Zay HPE SGI 8600 supercomputer
As of February 2025, loading modules in this exact order works:
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
# Install JAX
pip install --upgrade "jax[cuda]"
# Install jaxDecomp with cuDecomp
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake # sometimes needed
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON
Note: If you only use the pure-JAX backend, you do not need NVHPC.
NERSC Perlmutter HPE Cray EX supercomputer
As of November 2022:
module load PrgEnv-nvhpc python
export CRAY_ACCEL_TARGET=nvidia80
# Install JAX
pip install --upgrade "jax[cuda]"
# Install jaxDecomp w/ cuDecomp
export CMAKE_PREFIX_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cmake
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON
By default, cuDecomp uses NCCL for inter-device communication. You can customize this at runtime:
import jaxdecomp
# Choose MPI or NVSHMEM for halo and transpose ops
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
This can also be managed via environment variables, as described in the docs.
The cuDecomp library can autotune the partition layout to maximize performance:
automesh = jaxdecomp.autotune(shape=[512,512,512])
# 'automesh' is an optimized partition layout.
# You can then create a JAX Sharding spec from this:
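# Note: PositionalSharding is deprecated in recent JAX releases; NamedSharding is the recommended replacement.
from jax.sharding import PositionalSharding
sharding = PositionalSharding(automesh)
License: This project is licensed under the MIT License.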
For more details, see the examples directory and the documentation. Contributions and issues are welcome!