@@ -865,6 +865,20 @@ def remat(
865865 static_argnums : int | tuple [int , ...] = (),
866866 policy : tp .Callable [..., bool ] | None = None ,
867867) -> F | tp .Callable [[F ], F ]:
868+ """A 'lifted' version of the
869+ `jax.checkpoint <https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html>`__
870+ (a.k.a. ``jax.remat``).
871+
872+ ``flax.nnx.remat``, similar to ``jax.checkpoint`` can provide control over, for
873+ example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus
874+ how they are recomputed during the backward pass, trading off memory and FLOPs.
875+
876+ Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.
877+
878+ To learn about ``jax.remat``, go to JAX's
879+ `fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
880+ and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.
881+ """
868882 if isinstance (f , Missing ):
869883 return functools .partial (
870884 remat ,
@@ -886,18 +900,3 @@ def remat(
886900 ),
887901 )
888902 )
889- """A 'lifted' version of the
890- `jax.checkpoint <https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html>`__
891- (a.k.a. ``jax.remat``).
892-
893- ``flax.nnx.remat``, similar to ``jax.checkpoint`` can provide control over, for
894- example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus
895- how they are recomputed during the backward pass, trading off memory and FLOPs.
896-
897- Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.
898-
899- To learn about ``jax.remat``, go to JAX's
900- `fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
901- and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.
902- """
903-
0 commit comments