- Breaking changes:
  - JAX is changing the default `jax.pmap` implementation to one implemented in
    terms of `jax.jit` and `jax.shard_map`. `jax.pmap` is in maintenance mode
    and we encourage all new code to use `jax.shard_map` directly. See the
    migration guide for more information.
  - The `auto=` parameter of `jax.experimental.shard_map.shard_map` has been
    removed. This means that `jax.experimental.shard_map.shard_map` no longer
    supports nesting. If you want to nest `shard_map` calls, please use
    `jax.shard_map`.
  - JAX no longer allows passing objects that support `__jax_array__` directly
    to, e.g., `jit`-ed functions. Call `jax.numpy.asarray` on them first.
  - `jax.numpy.cov` now returns NaN for empty arrays ({jax-issue}`#32305`),
    and matches NumPy 2.2 behavior for single-row design matrices ({jax-issue}`#32308`).
  - JAX no longer accepts `Array` values where a `dtype` value is expected. Call
    `.dtype` on these values first.
  - The deprecated function `jax.interpreters.mlir.custom_call` was
    removed.
  - The `jax.util`, `jax.extend.ffi`, and `jax.experimental.host_callback`
    modules have been removed. All public APIs within these modules were
    deprecated and removed in v0.7.0 or earlier.
  - The deprecated symbol `jax.custom_derivatives.custom_jvp_call_jaxpr_p`
    was removed.
  - `jax.experimental.multihost_utils.process_allgather` raises an error when
    the input is a `jax.Array` that is not fully addressable and `tiled=False`.
    To fix this, pass `tiled=True` to your `process_allgather` invocation.
  - From `jax.experimental.compilation_cache`, the deprecated symbols
    `is_initialized` and `initialize_cache` were removed.
  - The deprecated function `jax.interpreters.xla.canonicalize_dtype`
    was removed.
  - `jaxlib.hlo_helpers` has been removed. Use `jax.ffi` instead.
  - The option `jax_cpu_enable_gloo_collectives` has been removed. Use
    `jax_cpu_collectives_implementation` instead.
  - The previously-deprecated `interpolation` argument to
    `jax.numpy.percentile` and `jax.numpy.quantile` has been
    removed; use `method` instead.
  - The JAX-internal `for_loop` primitive was removed. Its functionality,
    reading from and writing to refs in the loop body, is now directly
    supported by `jax.lax.fori_loop`. If you need help updating your
    code, please file a bug.
  - `jax.numpy.trimzeros` now errors for non-1D input.
  - The `where` argument to `jax.numpy.sum` and other reductions is now
    required to be boolean. Non-boolean values have resulted in a
    `DeprecationWarning` since JAX v0.5.0.
  - The deprecated functions in `jax.dlpack`, `jax.errors`,
    `jax.lib.xla_bridge`, `jax.lib.xla_client`, and
    `jax.lib.xla_extension` were removed.
  - `jax.interpreters.mlir.dense_bool_array` was removed. Use MLIR APIs to
    construct attributes instead.
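Several of the breaking changes above come down to small call-site updates. The following is a minimal sketch of four of them, not an official migration recipe. The `Wrapper` class, the `double` function, and the sample values are hypothetical and exist only for illustration.

```python
import jax
import jax.numpy as jnp


class Wrapper:
    """Hypothetical container that exposes the __jax_array__ protocol."""

    def __init__(self, values):
        self.values = values

    def __jax_array__(self):
        return jnp.asarray(self.values)


@jax.jit
def double(a):
    return a * 2


wrapped = Wrapper([1.0, 2.0, 3.0, 4.0])

# Convert __jax_array__ objects explicitly before handing them to jit-ed code.
x = jnp.asarray(wrapped)
doubled = double(x)

# Use `method=` rather than the removed `interpolation=` argument.
median = jnp.quantile(x, 0.5, method="linear")

# Reductions need a boolean `where=` mask; cast integer masks explicitly.
int_mask = jnp.array([1, 0, 1, 0])
masked_sum = jnp.sum(x, where=int_mask.astype(bool))

# Pass a dtype, not an Array, wherever a dtype is expected.
zeros = jnp.zeros(3, dtype=x.dtype)
```

Each of these patterns should also work on earlier JAX releases, so the call sites can be updated before upgrading.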
- Changes
  - `jax.numpy.linalg.eig` now returns a namedtuple (with attributes
    `eigenvalues` and `eigenvectors`) instead of a plain tuple.
  - `jax.grad` and `jax.vjp` will now always round primals to
    `float32` if `float64` mode is not enabled.
  - `jax.dlpack.from_dlpack` now accepts arrays with non-default layouts,
    for example, transposed.
  - The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses
    cusolver. The magma and LAPACK implementations are still available via the
    new `implementation` argument to `jax.lax.linalg.eig`
    ({jax-issue}`#27265`). The `use_magma` argument is now deprecated in favor
    of `implementation`.
  - `jax.numpy.trim_zeros` now follows NumPy 2.2 in supporting
    multi-dimensional inputs.
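Because a namedtuple is still a tuple, code that unpacks the result of `jax.numpy.linalg.eig` positionally should keep working after this change. The sketch below illustrates that with a hypothetical 2x2 diagonal matrix. The `getattr` fallback is an assumption added so the snippet also runs on releases that still return a plain tuple.

```python
import jax.numpy as jnp

# Hypothetical input: a diagonal matrix with eigenvalues 2 and 3.
a = jnp.array([[2.0, 0.0],
               [0.0, 3.0]])

result = jnp.linalg.eig(a)

# Positional unpacking works for both a plain tuple and a namedtuple.
w, v = result

# Attribute access is available once the result is a namedtuple;
# fall back to the unpacked values on older releases (assumption).
eigenvalues = getattr(result, "eigenvalues", w)
eigenvectors = getattr(result, "eigenvectors", v)

# Each column of `eigenvectors` satisfies a @ v_j == w_j * v_j.
residual = a @ eigenvectors - eigenvectors * eigenvalues
```

New code can prefer the named attributes for readability, while existing positional unpacking does not need to change.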
- Deprecations
  - `jax.experimental.enable_x64` and `jax.experimental.disable_x64`
    are deprecated in favor of the new non-experimental context manager
    `jax.enable_x64`.
  - `jax.experimental.shard_map.shard_map` is deprecated; going forward use
    `jax.shard_map`.
  - `jax.experimental.pjit.pjit` is deprecated; going forward use
    `jax.jit`.
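As a rough sketch of moving to the non-experimental context manager, the snippet below enables 64-bit mode for a single block. It assumes that `jax.enable_x64` accepts the desired boolean value as its argument. The fallback import is a compatibility shim added for illustration, so the snippet also runs on releases that predate `jax.enable_x64`.

```python
import jax
import jax.numpy as jnp

# Prefer the non-experimental context manager when it exists (assumption:
# it takes the new value as its argument); otherwise fall back to the
# deprecated experimental one so the sketch also runs on older releases.
enable_x64 = getattr(jax, "enable_x64", None)
if enable_x64 is None:
    from jax.experimental import enable_x64

with enable_x64(True):
    # Inside the block, Python floats default to 64-bit precision.
    x64_values = jnp.arange(3.0)

x64_dtype = x64_values.dtype
```

Code that only targets releases with `jax.enable_x64` can drop the fallback and call it directly.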