This is a JAX implementation of the differentiable JPEG compression algorithm. It is based on the PyTorch implementation and incorporates some of the modifications found in this repository to improve quality at high compression rates.
Requirements:
- JAX
Can be installed with pip:
pip install diffjpeg_jax

Unlike the PyTorch version, this implementation is not tied to any ML library; it is exposed as a plain function. Inputs should be in the range [0, 255] and have the shape (H, W, C).

from diffjpeg_jax import diff_jpeg
img = ... # (H, W, C)
jpeg = diff_jpeg(img, quality=75)

Note: The function is not wrapped in jax.jit, so apply jax.jit yourself if you want it compiled. For batch processing, use jax.vmap.
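The input format and the note on JIT and vmap can be combined into one small pipeline. The following is a minimal sketch, not part of the package: `stand_in_codec` is a hypothetical placeholder with the same call shape as `diff_jpeg(img, quality=...)` so that the snippet runs without `diffjpeg_jax` installed, and `to_hwc_255` is a hypothetical helper for converting PyTorch-style (C, H, W) images in [0, 1]. Replace `stand_in_codec` with `diff_jpeg` for real use.

```python
from functools import partial

import jax
import jax.numpy as jnp


def stand_in_codec(img, quality=75):
    """Hypothetical placeholder with the same call shape as diff_jpeg.

    NOT JPEG: it only quantizes pixel values so this snippet runs
    without the diffjpeg_jax package. Swap in diff_jpeg for real use.
    """
    step = 1.0 + (100 - quality) / 10.0  # coarser steps at lower quality
    return jnp.clip(jnp.round(img / step) * step, 0.0, 255.0)


def to_hwc_255(img_chw_unit):
    """Convert a (C, H, W) array in [0, 1] to (H, W, C) in [0, 255]."""
    return jnp.transpose(img_chw_unit, (1, 2, 0)) * 255.0


# Bind quality as a plain Python value so only the image is traced,
# then map over the leading batch axis and compile the result.
compress_batch = jax.jit(jax.vmap(partial(stand_in_codec, quality=75)))

key = jax.random.PRNGKey(0)
batch_chw = jax.random.uniform(key, (4, 3, 16, 16))  # (N, C, H, W) in [0, 1)
batch_hwc = jax.vmap(to_hwc_255)(batch_chw)          # (N, H, W, C) in [0, 255)
out = compress_batch(batch_hwc)
print(out.shape)
```

Binding `quality` with `functools.partial` before calling `jax.vmap` and `jax.jit` keeps it a compile-time constant. This README does not say whether `quality` may be a traced value, so treating it as static is the conservative choice.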