Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docs_nnx/api_reference/flax.nnx/variables.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ variables
:members:
.. autoclass:: VariableMetadata
:members:
.. autoclass:: VariableState
:members:

.. autofunction:: with_metadata

Expand Down
26 changes: 13 additions & 13 deletions docs_nnx/guides/checkpointing.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs_nnx/guides/checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ checkpointer.save(ckpt_dir / 'state', state)

## Restore checkpoints

Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.
Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.

At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows:
- First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library.
Expand Down
54 changes: 20 additions & 34 deletions docs_nnx/guides/filters_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,12 @@
"output_type": "stream",
"text": [
"params = State({\n",
" 'a': VariableState(\n",
" type=Param,\n",
" 'a': Param(\n",
" value=0\n",
" )\n",
"})\n",
"batch_stats = State({\n",
" 'b': VariableState(\n",
" type=BatchStat,\n",
" 'b': BatchStat(\n",
" value=True\n",
" )\n",
"})\n"
Expand Down Expand Up @@ -103,27 +101,23 @@
"name": "stdout",
"output_type": "stream",
"text": [
"is_param((), nnx.Param(0)) = True\n",
"is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True\n"
"is_param((), nnx.Param(0)) = True\n"
]
}
],
"source": [
"def is_param(path, value) -> bool:\n",
" return isinstance(value, nnx.Param) or (\n",
" hasattr(value, 'type') and issubclass(value.type, nnx.Param)\n",
" )\n",
" return isinstance(value, nnx.Param)\n",
"\n",
"print(f'{is_param((), nnx.Param(0)) = }')\n",
"print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')"
"print(f'{is_param((), nnx.Param(0)) = }')"
]
},
{
"cell_type": "markdown",
"id": "a8a2641e",
"metadata": {},
"source": [
"Such function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or any value that has a `type` attribute that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally Flax NNX uses `OfType` which defines a callable of this form for a given type:"
"Such function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally Flax NNX uses `OfType` which defines a callable of this form for a given type:"
]
},
{
Expand All @@ -136,16 +130,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
"is_param((), nnx.Param(0)) = True\n",
"is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True\n"
"is_param((), nnx.Param(0)) = True\n"
]
}
],
"source": [
"is_param = nnx.OfType(nnx.Param)\n",
"\n",
"print(f'{is_param((), nnx.Param(0)) = }')\n",
"print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')"
"print(f'{is_param((), nnx.Param(0)) = }')"
]
},
{
Expand Down Expand Up @@ -207,18 +199,18 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 5,
"id": "7e065fa9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"is_param = OfType(<class 'flax.nnx.nnx.variables.Param'>)\n",
"is_param = OfType(<class 'flax.nnx.variablelib.Param'>)\n",
"everything = Everything()\n",
"nothing = Nothing()\n",
"params_or_dropout = Any(OfType(<class 'flax.nnx.nnx.variables.Param'>), WithTag('dropout'))\n"
"params_or_dropout = Any(OfType(<class 'flax.nnx.variablelib.Param'>), WithTag('dropout'))\n"
]
}
],
Expand Down Expand Up @@ -252,7 +244,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "068208fc",
"metadata": {},
"outputs": [
Expand All @@ -261,14 +253,12 @@
"output_type": "stream",
"text": [
"params = State({\n",
" 'a': VariableState(\n",
" type=Param,\n",
" 'a': Param(\n",
" value=0\n",
" )\n",
"})\n",
"batch_stats = State({\n",
" 'b': VariableState(\n",
" type=BatchStat,\n",
" 'b': BatchStat(\n",
" value=True\n",
" )\n",
"})\n"
Expand Down Expand Up @@ -323,7 +313,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 7,
"id": "014da4d4",
"metadata": {},
"outputs": [
Expand All @@ -332,12 +322,10 @@
"output_type": "stream",
"text": [
"params = State({\n",
" 'a': VariableState(\n",
" type=Param,\n",
" 'a': Param(\n",
" value=0\n",
" ),\n",
" 'b': VariableState(\n",
" type=SpecialParam,\n",
" 'b': SpecialParam(\n",
" value=0\n",
" )\n",
"})\n",
Expand Down Expand Up @@ -371,7 +359,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"id": "a2ebf5b2",
"metadata": {},
"outputs": [
Expand All @@ -380,14 +368,12 @@
"output_type": "stream",
"text": [
"params = State({\n",
" 'a': VariableState(\n",
" type=Param,\n",
" 'a': Param(\n",
" value=0\n",
" )\n",
"})\n",
"special_params = State({\n",
" 'b': VariableState(\n",
" type=SpecialParam,\n",
" 'b': SpecialParam(\n",
" value=0\n",
" )\n",
"})\n"
Expand Down
8 changes: 2 additions & 6 deletions docs_nnx/guides/filters_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,17 @@ Types are not functions of this form. They are treated as `Filter`s because, as

```{code-cell} ipython3
def is_param(path, value) -> bool:
return isinstance(value, nnx.Param) or (
hasattr(value, 'type') and issubclass(value.type, nnx.Param)
)
return isinstance(value, nnx.Param)

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
```

Such function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or any value that has a `type` attribute that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally Flax NNX uses `OfType` which defines a callable of this form for a given type:
Such function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally Flax NNX uses `OfType` which defines a callable of this form for a given type:

```{code-cell} ipython3
is_param = nnx.OfType(nnx.Param)

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
```

## The `Filter` DSL
Expand Down
Loading
Loading