Skip to content

Commit a3f5be2

Browse files
author
Flax Authors
committed
Merge pull request #4790 from google:enable-docs-nnx-remat
PiperOrigin-RevId: 781707304
2 parents 188baab + 4870afe commit a3f5be2

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

flax/nnx/transforms/autodiff.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,20 @@ def remat(
865865
static_argnums: int | tuple[int, ...] = (),
866866
policy: tp.Callable[..., bool] | None = None,
867867
) -> F | tp.Callable[[F], F]:
868+
"""A 'lifted' version of the
869+
`jax.checkpoint <https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html>`__
870+
(a.k.a. ``jax.remat``).
871+
872+
``flax.nnx.remat``, similar to ``jax.checkpoint`` can provide control over, for
873+
example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus
874+
how they are recomputed during the backward pass, trading off memory and FLOPs.
875+
876+
Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.
877+
878+
To learn about ``jax.remat``, go to JAX's
879+
`fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
880+
and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.
881+
"""
868882
if isinstance(f, Missing):
869883
return functools.partial(
870884
remat,
@@ -886,18 +900,3 @@ def remat(
886900
),
887901
)
888902
)
889-
"""A 'lifted' version of the
890-
`jax.checkpoint <https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html>`__
891-
(a.k.a. ``jax.remat``).
892-
893-
``flax.nnx.remat``, similar to ``jax.checkpoint`` can provide control over, for
894-
example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus
895-
how they are recomputed during the backward pass, trading off memory and FLOPs.
896-
897-
Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.
898-
899-
To learn about ``jax.remat``, go to JAX's
900-
`fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
901-
and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.
902-
"""
903-

0 commit comments

Comments
 (0)