From e8984b709e61d9d4c4b3f1c3fa859cbb93c47e58 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 2 Jul 2025 21:46:51 -0700 Subject: [PATCH] [nnx] remove VariableState --- docs_nnx/api_reference/flax.nnx/variables.rst | 2 - docs_nnx/guides/checkpointing.ipynb | 26 +- docs_nnx/guides/checkpointing.md | 2 +- docs_nnx/guides/filters_guide.ipynb | 54 +-- docs_nnx/guides/filters_guide.md | 8 +- docs_nnx/guides/flax_gspmd.ipynb | 356 +++++++++--------- docs_nnx/guides/flax_gspmd.md | 4 +- docs_nnx/guides/haiku_to_flax.rst | 12 +- docs_nnx/guides/linen_to_nnx.rst | 12 +- docs_nnx/guides/surgery.ipynb | 64 ++-- docs_nnx/guides/surgery.md | 18 +- docs_nnx/nnx_glossary.rst | 3 - .../nnx_toy_examples/10_fsdp_and_optimizer.py | 12 +- .../nnx_toy_examples/mutable_array_demo.py | 8 +- flax/nnx/__init__.py | 1 - flax/nnx/bridge/module.py | 18 +- flax/nnx/bridge/variables.py | 19 +- flax/nnx/bridge/wrappers.py | 10 +- flax/nnx/filterlib.py | 4 +- flax/nnx/graph.py | 275 +++++--------- flax/nnx/nn/linear.py | 9 +- flax/nnx/nn/normalization.py | 29 +- flax/nnx/object.py | 4 +- flax/nnx/spmd.py | 50 +-- flax/nnx/statelib.py | 10 +- flax/nnx/summary.py | 6 +- flax/nnx/training/optimizer.py | 118 ++---- flax/nnx/transforms/autodiff.py | 28 +- flax/nnx/transforms/iteration.py | 52 +-- flax/nnx/variablelib.py | 229 ++--------- tests/nnx/graph_utils_test.py | 51 +-- tests/nnx/module_test.py | 22 +- tests/nnx/nn/attention_test.py | 4 +- tests/nnx/optimizer_test.py | 4 +- tests/nnx/state_test.py | 22 +- tests/nnx/transforms_test.py | 13 +- tests/nnx/variable_test.py | 4 +- 37 files changed, 594 insertions(+), 969 deletions(-) diff --git a/docs_nnx/api_reference/flax.nnx/variables.rst b/docs_nnx/api_reference/flax.nnx/variables.rst index 83f81c546..c135ae3bd 100644 --- a/docs_nnx/api_reference/flax.nnx/variables.rst +++ b/docs_nnx/api_reference/flax.nnx/variables.rst @@ -16,8 +16,6 @@ variables :members: .. autoclass:: VariableMetadata :members: -.. autoclass:: VariableState - :members: .. autofunction:: with_metadata diff --git a/docs_nnx/guides/checkpointing.ipynb b/docs_nnx/guides/checkpointing.ipynb index af32f00a9..243fc404a 100644 --- a/docs_nnx/guides/checkpointing.ipynb +++ b/docs_nnx/guides/checkpointing.ipynb @@ -88,7 +88,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -100,7 +100,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -131,7 +131,7 @@ "\n", "## Restore checkpoints\n", "\n", - "Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.\n", + "Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.\n", "\n", "At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows:\n", "- First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library.\n", @@ -153,7 +153,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -173,14 +173,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", + "/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1251: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -192,7 +192,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -252,13 +252,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -270,7 +270,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -325,7 +325,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -338,7 +338,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -350,7 +350,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" diff --git a/docs_nnx/guides/checkpointing.md b/docs_nnx/guides/checkpointing.md index cc0101c25..3cd828bb1 100644 --- a/docs_nnx/guides/checkpointing.md +++ b/docs_nnx/guides/checkpointing.md @@ -82,7 +82,7 @@ checkpointer.save(ckpt_dir / 'state', state) ## Restore checkpoints -Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes. +Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes. At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows: - First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library. diff --git a/docs_nnx/guides/filters_guide.ipynb b/docs_nnx/guides/filters_guide.ipynb index e15b90dbe..045553eeb 100644 --- a/docs_nnx/guides/filters_guide.ipynb +++ b/docs_nnx/guides/filters_guide.ipynb @@ -29,14 +29,12 @@ "output_type": "stream", "text": [ "params = State({\n", - " 'a': VariableState(\n", - " type=Param,\n", + " 'a': Param(\n", " value=0\n", " )\n", "})\n", "batch_stats = State({\n", - " 'b': VariableState(\n", - " type=BatchStat,\n", + " 'b': BatchStat(\n", " value=True\n", " )\n", "})\n" @@ -103,19 +101,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "is_param((), nnx.Param(0)) = True\n", - "is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True\n" + "is_param((), nnx.Param(0)) = True\n" ] } ], "source": [ "def is_param(path, value) -> bool:\n", - " return isinstance(value, nnx.Param) or (\n", - " hasattr(value, 'type') and issubclass(value.type, nnx.Param)\n", - " )\n", + " return isinstance(value, nnx.Param)\n", "\n", - "print(f'{is_param((), nnx.Param(0)) = }')\n", - "print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')" + "print(f'{is_param((), nnx.Param(0)) = }')" ] }, { @@ -123,7 +117,7 @@ "id": "a8a2641e", "metadata": {}, "source": [ - "Such function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or any value that has a `type` attribute that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally Flax NNX uses `OfType` which defines a callable of this form for a given type:" + "Such function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally Flax NNX uses `OfType` which defines a callable of this form for a given type:" ] }, { @@ -136,16 +130,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "is_param((), nnx.Param(0)) = True\n", - "is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True\n" + "is_param((), nnx.Param(0)) = True\n" ] } ], "source": [ "is_param = nnx.OfType(nnx.Param)\n", "\n", - "print(f'{is_param((), nnx.Param(0)) = }')\n", - "print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')" + "print(f'{is_param((), nnx.Param(0)) = }')" ] }, { @@ -207,7 +199,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "id": "7e065fa9", "metadata": {}, "outputs": [ @@ -215,10 +207,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "is_param = OfType()\n", + "is_param = OfType()\n", "everything = Everything()\n", "nothing = Nothing()\n", - "params_or_dropout = Any(OfType(), WithTag('dropout'))\n" + "params_or_dropout = Any(OfType(), WithTag('dropout'))\n" ] } ], @@ -252,7 +244,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "068208fc", "metadata": {}, "outputs": [ @@ -261,14 +253,12 @@ "output_type": "stream", "text": [ "params = State({\n", - " 'a': VariableState(\n", - " type=Param,\n", + " 'a': Param(\n", " value=0\n", " )\n", "})\n", "batch_stats = State({\n", - " 'b': VariableState(\n", - " type=BatchStat,\n", + " 'b': BatchStat(\n", " value=True\n", " )\n", "})\n" @@ -323,7 +313,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 7, "id": "014da4d4", "metadata": {}, "outputs": [ @@ -332,12 +322,10 @@ "output_type": "stream", "text": [ "params = State({\n", - " 'a': VariableState(\n", - " type=Param,\n", + " 'a': Param(\n", " value=0\n", " ),\n", - " 'b': VariableState(\n", - " type=SpecialParam,\n", + " 'b': SpecialParam(\n", " value=0\n", " )\n", "})\n", @@ -371,7 +359,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 8, "id": "a2ebf5b2", "metadata": {}, "outputs": [ @@ -380,14 +368,12 @@ "output_type": "stream", "text": [ "params = State({\n", - " 'a': VariableState(\n", - " type=Param,\n", + " 'a': Param(\n", " value=0\n", " )\n", "})\n", "special_params = State({\n", - " 'b': VariableState(\n", - " type=SpecialParam,\n", + " 'b': SpecialParam(\n", " value=0\n", " )\n", "})\n" diff --git a/docs_nnx/guides/filters_guide.md b/docs_nnx/guides/filters_guide.md index b4997a742..61329bb27 100644 --- a/docs_nnx/guides/filters_guide.md +++ b/docs_nnx/guides/filters_guide.md @@ -62,21 +62,17 @@ Types are not functions of this form. They are treated as `Filter`s because, as ```{code-cell} ipython3 def is_param(path, value) -> bool: - return isinstance(value, nnx.Param) or ( - hasattr(value, 'type') and issubclass(value.type, nnx.Param) - ) + return isinstance(value, nnx.Param) print(f'{is_param((), nnx.Param(0)) = }') -print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ``` -Such function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or any value that has a `type` attribute that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally Flax NNX uses `OfType` which defines a callable of this form for a given type: +Such function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally Flax NNX uses `OfType` which defines a callable of this form for a given type: ```{code-cell} ipython3 is_param = nnx.OfType(nnx.Param) print(f'{is_param((), nnx.Param(0)) = }') -print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ``` ## The `Filter` DSL diff --git a/docs_nnx/guides/flax_gspmd.ipynb b/docs_nnx/guides/flax_gspmd.ipynb index b428c0ac3..0fa08da61 100644 --- a/docs_nnx/guides/flax_gspmd.ipynb +++ b/docs_nnx/guides/flax_gspmd.ipynb @@ -48,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -58,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -76,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -108,14 +108,14 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Mesh('data': 2, 'model': 4)\n" + "Mesh('data': 2, 'model': 4, axis_types=(Auto, Auto))\n" ] } ], @@ -142,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -204,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -251,15 +251,15 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)\n", - "NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model',), memory_kind=unpinned_host)\n" + "NamedSharding(mesh=Mesh('data': 2, 'model': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)\n", + "NamedSharding(mesh=Mesh('data': 2, 'model': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('model',), memory_kind=unpinned_host)\n" ] } ], @@ -298,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -311,31 +311,31 @@ { "data": { "text/html": [ - "
┌───────┬───────┬───────┬───────┐\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "└───────┴───────┴───────┴───────┘\n",
+       "
                                    \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
+       " CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
        "
\n" ], "text/plain": [ - "┌───────┬───────┬───────┬───────┐\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m│CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m│CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m│CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m│\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "└───────┴───────┴───────┴───────┘\n" + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ] }, "metadata": {}, @@ -351,27 +351,33 @@ { "data": { "text/html": [ - "
┌───────────────────────┐\n",
-       "│        CPU 0,4        │\n",
-       "├───────────────────────┤\n",
-       "│        CPU 1,5        │\n",
-       "├───────────────────────┤\n",
-       "│        CPU 2,6        │\n",
-       "├───────────────────────┤\n",
-       "│        CPU 3,7        │\n",
-       "└───────────────────────┘\n",
+       "
                         \n",
+       "         CPU 0,4         \n",
+       "                         \n",
+       "                         \n",
+       "         CPU 1,5         \n",
+       "                         \n",
+       "                         \n",
+       "         CPU 2,6         \n",
+       "                         \n",
+       "                         \n",
+       "         CPU 3,7         \n",
+       "                         \n",
        "
\n" ], "text/plain": [ - "┌───────────────────────┐\n", - "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m │\n", - "├───────────────────────┤\n", - "│ CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m │\n", - "├───────────────────────┤\n", - "│ CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m │\n", - "├───────────────────────┤\n", - "│ CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m │\n", - "└───────────────────────┘\n" + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ] }, "metadata": {}, @@ -415,37 +421,37 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
┌───────┬───────┬───────┬───────┐\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "└───────┴───────┴───────┴───────┘\n",
+       "
                                    \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
+       " CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
        "
\n" ], "text/plain": [ - "┌───────┬───────┬───────┬───────┐\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m│CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m│CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m│CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m│\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "└───────┴───────┴───────┴───────┘\n" + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ] }, "metadata": {}, @@ -454,27 +460,33 @@ { "data": { "text/html": [ - "
┌───────────────────────┐\n",
-       "│        CPU 0,4        │\n",
-       "├───────────────────────┤\n",
-       "│        CPU 1,5        │\n",
-       "├───────────────────────┤\n",
-       "│        CPU 2,6        │\n",
-       "├───────────────────────┤\n",
-       "│        CPU 3,7        │\n",
-       "└───────────────────────┘\n",
+       "
                         \n",
+       "         CPU 0,4         \n",
+       "                         \n",
+       "                         \n",
+       "         CPU 1,5         \n",
+       "                         \n",
+       "                         \n",
+       "         CPU 2,6         \n",
+       "                         \n",
+       "                         \n",
+       "         CPU 3,7         \n",
+       "                         \n",
        "
\n" ], "text/plain": [ - "┌───────────────────────┐\n", - "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m │\n", - "├───────────────────────┤\n", - "│ CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m │\n", - "├───────────────────────┤\n", - "│ CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m │\n", - "├───────────────────────┤\n", - "│ CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m │\n", - "└───────────────────────┘\n" + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ] }, "metadata": {}, @@ -522,7 +534,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -535,31 +547,33 @@ { "data": { "text/html": [ - "
┌──────────────────────────────────────────────────────────────────────────────┐\n",
-       "│                                                                              │\n",
-       "│                                 CPU 0,1,2,3                                  │\n",
-       "│                                                                              │\n",
-       "│                                                                              │\n",
-       "├──────────────────────────────────────────────────────────────────────────────┤\n",
-       "│                                                                              │\n",
-       "│                                 CPU 4,5,6,7                                  │\n",
-       "│                                                                              │\n",
-       "│                                                                              │\n",
-       "└──────────────────────────────────────────────────────────────────────────────┘\n",
+       "
                                                                                \n",
+       "                                                                                \n",
+       "                                  CPU 0,1,2,3                                   \n",
+       "                                                                                \n",
+       "                                                                                \n",
+       "                                                                                \n",
+       "                                                                                \n",
+       "                                                                                \n",
+       "                                  CPU 4,5,6,7                                   \n",
+       "                                                                                \n",
+       "                                                                                \n",
+       "                                                                                \n",
        "
\n" ], "text/plain": [ - "┌──────────────────────────────────────────────────────────────────────────────┐\n", - "│ │\n", - "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", - "│ │\n", - "│ │\n", - "├──────────────────────────────────────────────────────────────────────────────┤\n", - "│ │\n", - "│ CPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n", - "│ │\n", - "│ │\n", - "└──────────────────────────────────────────────────────────────────────────────┘\n" + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,1,2,3\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 4,5,6,7\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n" ] }, "metadata": {}, @@ -588,18 +602,18 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "1.455235\n", - "0.7646729\n", - "0.50971293\n", - "0.378493\n", - "0.28089797\n" + "1.4929407\n", + "0.820176\n", + "0.5583741\n", + "0.41078538\n", + "0.2984159\n" ] } ], @@ -637,14 +651,14 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "7.89 ms ± 486 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "9.52 ms ± 142 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -673,7 +687,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -718,37 +732,37 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
┌───────┬───────┬───────┬───────┐\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "│       │       │       │       │\n",
-       "└───────┴───────┴───────┴───────┘\n",
+       "
                                    \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
+       " CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
+       "                                    \n",
        "
\n" ], "text/plain": [ - "┌───────┬───────┬───────┬───────┐\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m│CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m│CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m│CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m│\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "└───────┴───────┴───────┴───────┘\n" + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ] }, "metadata": {}, @@ -757,27 +771,33 @@ { "data": { "text/html": [ - "
┌───────────────────────┐\n",
-       "│        CPU 0,4        │\n",
-       "├───────────────────────┤\n",
-       "│        CPU 1,5        │\n",
-       "├───────────────────────┤\n",
-       "│        CPU 2,6        │\n",
-       "├───────────────────────┤\n",
-       "│        CPU 3,7        │\n",
-       "└───────────────────────┘\n",
+       "
                         \n",
+       "         CPU 0,4         \n",
+       "                         \n",
+       "                         \n",
+       "         CPU 1,5         \n",
+       "                         \n",
+       "                         \n",
+       "         CPU 2,6         \n",
+       "                         \n",
+       "                         \n",
+       "         CPU 3,7         \n",
+       "                         \n",
        "
\n" ], "text/plain": [ - "┌───────────────────────┐\n", - "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m │\n", - "├───────────────────────┤\n", - "│ CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m │\n", - "├───────────────────────┤\n", - "│ CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m │\n", - "├───────────────────────┤\n", - "│ CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m │\n", - "└───────────────────────┘\n" + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ] }, "metadata": {}, @@ -785,7 +805,7 @@ } ], "source": [ - "def add_sharding_rule(vs: nnx.VariableState) -> nnx.VariableState:\n", + "def add_sharding_rule(vs: nnx.Variable) -> nnx.Variable:\n", " vs.sharding_rules = sharding_rules\n", " return vs\n", "\n", @@ -794,7 +814,7 @@ " model = LogicalDotReluDot(1024, rngs=nnx.Rngs(0))\n", " state = nnx.state(model)\n", " state = jax.tree.map(add_sharding_rule, state,\n", - " is_leaf=lambda x: isinstance(x, nnx.VariableState))\n", + " is_leaf=lambda x: isinstance(x, nnx.Variable))\n", " pspecs = nnx.get_partition_spec(state)\n", " sharded_state = jax.lax.with_sharding_constraint(state, pspecs)\n", " nnx.update(model, sharded_state)\n", @@ -854,7 +874,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.10" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/docs_nnx/guides/flax_gspmd.md b/docs_nnx/guides/flax_gspmd.md index 7c3a73cf0..73f057487 100644 --- a/docs_nnx/guides/flax_gspmd.md +++ b/docs_nnx/guides/flax_gspmd.md @@ -344,7 +344,7 @@ class LogicalDotReluDot(nnx.Module): If you didn't provide all `sharding_rule` annotations in the model definition, you can write a few lines to add it to Flax’s [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). ```{code-cell} ipython3 -def add_sharding_rule(vs: nnx.VariableState) -> nnx.VariableState: +def add_sharding_rule(vs: nnx.Variable) -> nnx.Variable: vs.sharding_rules = sharding_rules return vs @@ -353,7 +353,7 @@ def create_sharded_logical_model(): model = LogicalDotReluDot(1024, rngs=nnx.Rngs(0)) state = nnx.state(model) state = jax.tree.map(add_sharding_rule, state, - is_leaf=lambda x: isinstance(x, nnx.VariableState)) + is_leaf=lambda x: isinstance(x, nnx.Variable)) pspecs = nnx.get_partition_spec(state) sharded_state = jax.lax.with_sharding_constraint(state, pspecs) nnx.update(model, sharded_state) diff --git a/docs_nnx/guides/haiku_to_flax.rst b/docs_nnx/guides/haiku_to_flax.rst index 5cfb55a57..746c55ad1 100644 --- a/docs_nnx/guides/haiku_to_flax.rst +++ b/docs_nnx/guides/haiku_to_flax.rst @@ -380,12 +380,12 @@ The parameter structure is as follows: params { 'decoder': { - 'bias': VariableState(type=Param, value=(784,)), - 'kernel': VariableState(type=Param, value=(256, 784)) + 'bias': Param(value=(784,)), + 'kernel': Param(value=(256, 784)) }, 'encoder': { - 'bias': VariableState(type=Param, value=(256,)), - 'kernel': VariableState(type=Param, value=(784, 256)) + 'bias': Param(value=(256,)), + 'kernel': Param(value=(784, 256)) } } @@ -637,8 +637,8 @@ Now inspect the variable pytree on both sides: { 'blocks': { 'linear': { - 'bias': VariableState(type=Param, value=(5, 64)), - 'kernel': VariableState(type=Param, value=(5, 64, 64)) + 'bias': Param(value=(5, 64)), + 'kernel': Param(value=(5, 64, 64)) } } } diff --git a/docs_nnx/guides/linen_to_nnx.rst b/docs_nnx/guides/linen_to_nnx.rst index f60b4dea1..63e7fafb3 100644 --- a/docs_nnx/guides/linen_to_nnx.rst +++ b/docs_nnx/guides/linen_to_nnx.rst @@ -391,12 +391,12 @@ The variable structure is as follows: # params { 'decoder': { - 'bias': VariableState(type=Param, value=(784,)), - 'kernel': VariableState(type=Param, value=(256, 784)) + 'bias': Param(value=(784,)), + 'kernel': Param(value=(256, 784)) }, 'encoder': { - 'bias': VariableState(type=Param, value=(256,)), - 'kernel': VariableState(type=Param, value=(784, 256)) + 'bias': Param(value=(256,)), + 'kernel': Param(value=(784, 256)) } } @@ -647,8 +647,8 @@ Now inspect the variable pytree on both sides: { 'blocks': { 'linear': { - 'bias': VariableState(type=Param, value=(5, 64)), - 'kernel': VariableState(type=Param, value=(5, 64, 64)) + 'bias': Param(value=(5, 64)), + 'kernel': Param(value=(5, 64, 64)) } } } diff --git a/docs_nnx/guides/surgery.ipynb b/docs_nnx/guides/surgery.ipynb index edbe22975..f55a389ab 100644 --- a/docs_nnx/guides/surgery.ipynb +++ b/docs_nnx/guides/surgery.ipynb @@ -73,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -133,22 +133,18 @@ "text": [ "State({\n", " 'linear1': {\n", - " 'bias': VariableState(\n", - " type=Param,\n", + " 'bias': Param( # 4 (16 B)\n", " value=ShapeDtypeStruct(shape=(4,), dtype=float32)\n", " ),\n", - " 'kernel': VariableState(\n", - " type=Param,\n", + " 'kernel': Param( # 16 (64 B)\n", " value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)\n", " )\n", " },\n", " 'linear2': {\n", - " 'bias': VariableState(\n", - " type=Param,\n", + " 'bias': Param( # 4 (16 B)\n", " value=ShapeDtypeStruct(shape=(4,), dtype=float32)\n", " ),\n", - " 'kernel': VariableState(\n", - " type=Param,\n", + " 'kernel': Param( # 16 (64 B)\n", " value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)\n", " )\n", " }\n", @@ -166,7 +162,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "When you fill every `nnx.VariableState` pytree leaf's `value` attributes with real `jax.Array`s, the abstract model becomes equivalent to a real model." + "When you fill every `nnx.Variable` pytree leaf's `value` attributes with real `jax.Array`s, the abstract model becomes equivalent to a real model." ] }, { @@ -176,10 +172,10 @@ "outputs": [], "source": [ "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", - "abs_state['linear1']['kernel'].value = model.linear1.kernel\n", - "abs_state['linear1']['bias'].value = model.linear1.bias\n", - "abs_state['linear2']['kernel'].value = model.linear2.kernel\n", - "abs_state['linear2']['bias'].value = model.linear2.bias\n", + "abs_state['linear1']['kernel'].value = model.linear1.kernel.value\n", + "abs_state['linear1']['bias'].value = model.linear1.bias.value\n", + "abs_state['linear2']['kernel'].value = model.linear2.kernel.value\n", + "abs_state['linear2']['bias'].value = model.linear2.bias.value\n", "nnx.update(abs_model, abs_state)\n", "np.testing.assert_allclose(abs_model(x), model(x)) # They are equivalent now!" ] @@ -225,7 +221,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "This will throw error: : Dict key mismatch; expected keys: ['linear1', 'linear2']; dict: {'layer1': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}, 'layer2': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}}.\n" + "This will throw error: : User-provided restore item and on-disk value metadata tree structures do not match: {'layer1': Diff(lhs={'bias': {'value': ShapeDtypeStruct(shape=(4,), dtype=float32)}, 'kernel': {'value': ShapeDtypeStruct(shape=(4, 4), dtype=float32)}}, rhs=None), 'layer2': Diff(lhs={'bias': {'value': ShapeDtypeStruct(shape=(4,), dtype=float32)}, 'kernel': {'value': ShapeDtypeStruct(shape=(4, 4), dtype=float32)}}, rhs=None), 'linear1': Diff(lhs=None, rhs={'bias': {'value': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(4,))}, 'kernel': {'value': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(4, 4))}}), 'linear2': Diff(lhs=None, rhs={'bias': {'value': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(4,))}, 'kernel': {'value': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(4, 4))}})}\n" ] } ], @@ -256,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -264,22 +260,22 @@ "output_type": "stream", "text": [ "{'linear1': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},\n", - " 'kernel': {'value': Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],\n", - " [ 0.26146442, 1.1247735 , 0.54563737, -0.374164 ],\n", - " [ 1.0281805 , -0.6798804 , -0.1488401 , 0.05694951],\n", - " [-0.44308168, -0.60587114, 0.434087 , -0.40541083]], dtype=float32)}},\n", + " 'kernel': {'value': Array([[ 0.5350889 , -0.48486355, -0.4022262 , -0.61925626],\n", + " [-0.46665004, 0.31773907, 0.38944173, -0.54608804],\n", + " [ 0.84378934, -0.93099 , -0.67658 , 0.0724705 ],\n", + " [-0.6101737 , 0.12972134, 0.877074 , 0.27292168]], dtype=float32)}},\n", " 'linear2': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},\n", - " 'kernel': {'value': Array([[ 0.21010089, 0.8289361 , 0.04589564, 0.5422644 ],\n", - " [ 0.41914317, 0.84359694, -0.47937787, -0.49135214],\n", - " [-0.46072108, 0.4630125 , 0.39276958, -0.9441406 ],\n", - " [-0.6690758 , -0.18474789, -0.57622856, 0.4821079 ]], dtype=float32)}}}\n" + " 'kernel': {'value': Array([[ 0.67979455, 0.7079946 , -0.22166717, -0.4147039 ],\n", + " [ 0.20622818, 0.01024843, 0.31011865, -0.40491563],\n", + " [ 0.12478007, -0.7697264 , -0.48899388, 0.8853114 ],\n", + " [-0.5123713 , -0.23335123, 0.4374407 , 0.63321066]], dtype=float32)}}}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "/Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", + "/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1251: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", " warnings.warn(\n" ] } @@ -379,8 +375,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Number of jax arrays in memory at start: 44\n", - "Number of jax arrays in memory at end: 46 (2 new created - lora_a and lora_b)\n" + "Number of JAX Arrays in memory at start: 44\n", + "Number of JAX Arrays in memory at end: 50 (2 new created - lora_a and lora_b)\n" ] } ], @@ -404,20 +400,6 @@ "print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}'\n", " ' (2 new created - lora_a and lora_b)')" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs_nnx/guides/surgery.md b/docs_nnx/guides/surgery.md index 1df1ce596..252c35d0a 100644 --- a/docs_nnx/guides/surgery.md +++ b/docs_nnx/guides/surgery.md @@ -106,14 +106,14 @@ gdef, abs_state = nnx.split(abs_model) pprint(abs_state) ``` -When you fill every `nnx.VariableState` pytree leaf's `value` attributes with real `jax.Array`s, the abstract model becomes equivalent to a real model. +When you fill every `nnx.Variable` pytree leaf's `value` attributes with real `jax.Array`s, the abstract model becomes equivalent to a real model. ```{code-cell} ipython3 model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) -abs_state['linear1']['kernel'].value = model.linear1.kernel -abs_state['linear1']['bias'].value = model.linear1.bias -abs_state['linear2']['kernel'].value = model.linear2.kernel -abs_state['linear2']['bias'].value = model.linear2.bias +abs_state['linear1']['kernel'].value = model.linear1.kernel.value +abs_state['linear1']['bias'].value = model.linear1.bias.value +abs_state['linear2']['kernel'].value = model.linear2.kernel.value +abs_state['linear2']['bias'].value = model.linear2.bias.value nnx.update(abs_model, abs_state) np.testing.assert_allclose(abs_model(x), model(x)) # They are equivalent now! ``` @@ -233,11 +233,3 @@ good_model = partial_init(old_state, nnx.Rngs(42)) print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}' ' (2 new created - lora_a and lora_b)') ``` - -```{code-cell} ipython3 - -``` - -```{code-cell} ipython3 - -``` diff --git a/docs_nnx/nnx_glossary.rst b/docs_nnx/nnx_glossary.rst index 864c8a0ad..530290793 100644 --- a/docs_nnx/nnx_glossary.rst +++ b/docs_nnx/nnx_glossary.rst @@ -38,6 +38,3 @@ For additional terms, refer to the `JAX glossary ` residing in a Flax :term:`Module`. Variables are defined inside modules as :class:`nnx.Variable ` or its subclasses. - - Variable state - :class:`nnx.VariableState ` is a purely functional `JAX pytree `__ of all the :term:`Variables` inside a :term:`Module`. Since it is pure, it can be an input or output of a `JAX transformation `__ function. ``nnx.VariableState`` is obtained by using :meth:`nnx.split ` on the :class:`nnx.Module `. (Refer to :term:`splitting` and :term:`Module` to learn more.) diff --git a/examples/nnx_toy_examples/10_fsdp_and_optimizer.py b/examples/nnx_toy_examples/10_fsdp_and_optimizer.py index 83dd397b0..75d04f9d8 100644 --- a/examples/nnx_toy_examples/10_fsdp_and_optimizer.py +++ b/examples/nnx_toy_examples/10_fsdp_and_optimizer.py @@ -87,13 +87,13 @@ def init_optimizer_state(variable: nnx.Variable): self.momentum: nnx.State = jax.tree.map( init_optimizer_state, self.params, - is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState), + is_leaf=lambda x: isinstance(x, nnx.Variable), ) self.decay = decay def update(self, grads: nnx.State): def update_fn( - params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState + params: nnx.Variable, momentum: SGDState, grad: nnx.Variable ): # v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t) momentum[...] = self.decay * momentum[...] + (1 - self.decay) * grad[...] @@ -105,7 +105,7 @@ def update_fn( self.params, self.momentum, grads, - is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState), + is_leaf=lambda x: isinstance(x, nnx.Variable), ) @@ -118,12 +118,12 @@ def create_model(): state, nnx.get_named_sharding(state, mesh) ) - def get_named_shardings(path: tuple, value: nnx.VariableState): + def get_named_shardings(path: tuple, value: nnx.Variable): if path[0] == 'params': - return value.replace(NamedSharding(mesh, P(*value.sharding))) + return NamedSharding(mesh, P(*value.sharding)) elif path[0] == 'momentum': # currently the same as above but in general it could be different - return value.replace(NamedSharding(mesh, P(*value.sharding))) + return NamedSharding(mesh, P(*value.sharding)) else: raise ValueError(f'Unknown path: {path}') diff --git a/examples/nnx_toy_examples/mutable_array_demo.py b/examples/nnx_toy_examples/mutable_array_demo.py index 2f2bbc062..d499c9197 100644 --- a/examples/nnx_toy_examples/mutable_array_demo.py +++ b/examples/nnx_toy_examples/mutable_array_demo.py @@ -192,7 +192,7 @@ def __init__(self, params, lr: float, decay: float = 0.9): self.decay = decay def make_opt_state(x): - if isinstance(x, nnx.Variable | nnx.VariableState): + if isinstance(x, nnx.Variable): return OptState(jnp.zeros_like(x.value), **x.get_metadata()) else: return OptState(jnp.zeros_like(x)) @@ -201,7 +201,7 @@ def make_opt_state(x): jax.tree.map( make_opt_state, params, - is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState), + is_leaf=lambda x: isinstance(x, nnx.Variable), ) ) @@ -268,11 +268,11 @@ def test_step(model: Model, x, y): # minimalistic training loop -total_steps = 10_000 +total_steps = 2_000 for step, (x, y) in enumerate(dataset(32)): train_step(model, optimizer, rngs, x, y) - if step % 1000 == 0: + if step % 200 == 0: logs = test_step(eval_model, X, Y) print(f'step: {step}, loss: {logs["loss"]}') diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 1b820a9b5..2dc39d1b6 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -172,7 +172,6 @@ from .variablelib import Intermediate as Intermediate from .variablelib import Perturbation as Perturbation from .variablelib import Variable as Variable -from .variablelib import VariableState as VariableState from .variablelib import VariableMetadata as VariableMetadata from .variablelib import with_metadata as with_metadata from .variablelib import variable_type_from_name as variable_type_from_name diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py index cf128751e..1a8f06ba2 100644 --- a/flax/nnx/bridge/module.py +++ b/flax/nnx/bridge/module.py @@ -374,27 +374,27 @@ def _get_variables(self) -> tp.Mapping: state = graph.state(self) _variables: dict = {} - variable_state: variablelib.VariableState - for path, variable_state in statelib.to_flat_state(state): - if issubclass(variable_state.type, rnglib.RngState): + variable: variablelib.Variable + for path, variable in statelib.to_flat_state(state): + if isinstance(variable, rnglib.RngState): # Don't return RNG states, since Linen doesn't have them. continue try: - collection = variablelib.variable_name_from_type(variable_state.type) + collection = variablelib.variable_name_from_type(type(variable)) except ValueError: - collection = variable_state.type.__name__ + collection = type(variable).__name__ if collection not in _variables: _variables[collection] = {} if ( - isinstance(variable_state, variablelib.VariableState) - and not variable_state._var_metadata + isinstance(variable, variablelib.Variable) + and not variable._var_metadata ): - leaf = variable_state.value + leaf = variable.value else: - leaf = bridge_variables.to_linen_var(variable_state) + leaf = bridge_variables.to_linen_var(variable) _variables[collection][path] = leaf diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py index 030440d09..235b506f5 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -43,7 +43,7 @@ def _variable_parents_count(t: type): class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]): - """Default Flax metadata class for `nnx.VariableState`.""" + """Default Flax metadata class for `nnx.Variable`.""" var_type: type[variablelib.Variable[tp.Any]] = struct.field(pytree_node=False) value: Any = struct.field(pytree_node=True) @@ -65,15 +65,17 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]': def get_partition_spec(self) -> jax.sharding.PartitionSpec: """Returns the ``Partitionspec`` for this partitioned value.""" - nnx_var = self.to_nnx_variable().to_state() - return spmd.get_partition_spec(nnx_var).raw_value + nnx_var = self.to_nnx_variable() + spec = spmd.get_partition_spec(nnx_var).raw_value + assert isinstance(spec, jax.sharding.PartitionSpec) + return spec def to_nnx_variable(self) -> variablelib.Variable: return self.var_type(self.value, **self.metadata) -def is_vanilla_variable(vs: variablelib.VariableState) -> bool: - """A variables state is vanilla if its metadata is essentially blank. +def is_vanilla_variable(vs: variablelib.Variable) -> bool: + """A variable is vanilla if its metadata is essentially blank. Returns False only if it has non-empty hooks or any non-built-in attribute. """ @@ -86,7 +88,7 @@ def is_vanilla_variable(vs: variablelib.VariableState) -> bool: return True -def to_linen_var(vs: variablelib.VariableState) -> meta.AxisMetadata: +def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata: metadata = vs.get_metadata() if 'linen_meta_type' in metadata: linen_type = metadata['linen_meta_type'] @@ -145,14 +147,11 @@ def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]: def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: - """Convert a dict of NNX variables (or variable states) to Linen-style variables.""" + """Convert a dict of NNX variables to Linen-style variables.""" linen_structured = {} for kp, v in traversals.flatten_mapping(nnx_attrs).items(): if isinstance(v, variablelib.Variable): col_name = variablelib.variable_name_from_type(type(v)) - v = to_linen_var(v.to_state()) - elif isinstance(v, variablelib.VariableState): - col_name = variablelib.variable_name_from_type(v.type) v = to_linen_var(v) elif isinstance(v, graph.GraphDef): col_name = 'nnx' # an nnx.GraphDef for some ToLinen submodule diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index 51e79f9ef..764177652 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -252,7 +252,7 @@ class ToLinen(linen.Module): args: tp.Sequence = () kwargs: tp.Mapping[str, tp.Any] = FrozenDict({}) skip_rng: bool = False - metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = ( + metadata_fn: tp.Callable[[variablelib.Variable], tp.Any] | None = ( bv.to_linen_var ) @@ -310,7 +310,7 @@ def _update_variables(self, module): # group state by collection for path, leaf in nnx.to_flat_state(state): - type_ = leaf.type if isinstance(leaf, nnx.VariableState) else type(leaf) + type_ = leaf.type if isinstance(leaf, nnx.Variable) else type(leaf) collection = variablelib.variable_name_from_type( type_, allow_register=True ) @@ -323,7 +323,7 @@ def _update_variables(self, module): if self.is_mutable_collection(collection): def _to_linen_var(x): - if isinstance(x, nnx.VariableState): + if isinstance(x, nnx.Variable): if self.metadata_fn: return self.metadata_fn(x) else: @@ -334,7 +334,7 @@ def _to_linen_var(x): collection_state = jax.tree.map( _to_linen_var, collection_state, - is_leaf=lambda x: isinstance(x, nnx.VariableState), + is_leaf=lambda x: isinstance(x, nnx.Variable), ) for k, v in collection_state.items(): self.put_variable(collection, k, v) @@ -344,7 +344,7 @@ def to_linen( nnx_class: tp.Callable[..., Module], *args, metadata_fn: ( - tp.Callable[[variablelib.VariableState], tp.Any] | None + tp.Callable[[variablelib.Variable], tp.Any] | None ) = bv.to_linen_var, name: str | None = None, **kwargs, diff --git a/flax/nnx/filterlib.py b/flax/nnx/filterlib.py index 6e886eae5..f293b6cc0 100644 --- a/flax/nnx/filterlib.py +++ b/flax/nnx/filterlib.py @@ -121,9 +121,7 @@ class OfType: type: type def __call__(self, path: PathParts, x: tp.Any): - return isinstance(x, self.type) or ( - hasattr(x, 'type') and issubclass(x.type, self.type) - ) + return isinstance(x, self.type) def __repr__(self): return f'OfType({self.type!r})' diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index d062666f0..96ce9a013 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -29,7 +29,7 @@ DelayedAccessor, ) from flax.nnx.statelib import FlatState, State -from flax.nnx.variablelib import Variable, VariableState +from flax.nnx.variablelib import Variable from flax.typing import Key, PathParts, is_key_like import jax import numpy as np @@ -90,7 +90,6 @@ def __treescope_repr__(self, path, subtree_renderer): LeafType = tp.Union[ Variable, - VariableState, jax.Array, np.ndarray, variablelib.MutableArray, @@ -102,7 +101,7 @@ def __treescope_repr__(self, path, subtree_renderer): def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[LeafType]: - return isinstance(x, LeafType) or variablelib.is_mutable_array(x) # type: ignore[misc, arg-type] + return isinstance(x, LeafType) or variablelib.is_mutable_array(x) # type: ignore[misc, arg-type] class IndexMap(dict[Index, tp.Any]): @@ -510,6 +509,7 @@ class ArrayAttr: ARRAY_ATTR = ArrayAttr() + @dataclasses.dataclass(frozen=True, slots=True) class MutableArrayAttr: pass @@ -532,6 +532,7 @@ class NodeAttr: 'Static[tp.Any]', ] + # GraphDef = tp.Union[NodeDef[Node], NodeRef[Node], VariableDef[Node]] @jax.tree_util.register_static @dataclasses.dataclass(frozen=True, slots=True) @@ -592,19 +593,18 @@ def flatten( # type: ignore[invalid-annotation] *, ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, -) -> tuple[GraphDef[Node], FlatState[VariableState[tp.Any]]]: ... +) -> tuple[GraphDef[Node], FlatState[tp.Any]]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] node: Node, /, *, with_paths: tp.Literal[True], - return_variables: tp.Literal[True], ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, ) -> tuple[ GraphDef[Node], - FlatState[Variable[tp.Any]], + FlatState[tp.Any], ]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] @@ -612,24 +612,11 @@ def flatten( # type: ignore[invalid-annotation] /, *, with_paths: tp.Literal[False], - return_variables: tp.Literal[True], ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, ) -> tuple[ GraphDef[Node], - list[Variable[tp.Any]], -]: ... -@tp.overload -def flatten( # type: ignore[invalid-annotation] - node: Node, - /, - *, - return_variables: tp.Literal[True], - ref_index: RefMap | None = None, - ref_outer_index: RefMap | None = None, -) -> tuple[ - GraphDef[Node], - FlatState[Variable[tp.Any]], + list[tp.Any], ]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] @@ -641,19 +628,18 @@ def flatten( # type: ignore[invalid-annotation] ref_outer_index: RefMap | None = None, ) -> tuple[ GraphDef[Node], - FlatState[VariableState[tp.Any]] | list[tp.Any], + FlatState[tp.Any] | list[tp.Any], ]: ... def flatten( # type: ignore[invalid-annotation] node: Node, /, *, with_paths: bool = True, - return_variables: bool = False, ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, ) -> tuple[ GraphDef[Node], - FlatState[VariableState[tp.Any]] | FlatState[Variable[tp.Any]] | list[tp.Any], + FlatState[tp.Any] | list[tp.Any], ]: """Flattens a graph node into a (graphdef, state) pair. @@ -663,12 +649,12 @@ def flatten( # type: ignore[invalid-annotation] empty dictionary is created. This argument can be used to flatten a sequence of graph nodes that share references. with_paths: A boolean that indicates whether to return a FlatState object that includes - the paths to VariableState objects, or just a list of the Variable's inner values. + the paths, or just a list of the Variable's inner values. """ if ref_index is None: ref_index = RefMap() - leaves: list[LeafType] = [] + leaves: list[tp.Any] = [] path: list[Key] | None = [] if with_paths else None paths: list[PathParts] | None = [] if with_paths else None nodes: list[NodeDefType[tp.Any]] = [] @@ -684,7 +670,6 @@ def flatten( # type: ignore[invalid-annotation] attributes, leaves, paths, - return_variables, ) graphdef: GraphDef = GraphDef( nodes=nodes, attributes=attributes, num_leaves=len(leaves) @@ -704,9 +689,8 @@ def _graph_flatten( ref_outer_index: RefMap | None, nodes: list[NodeDefType[tp.Any]], attributes: list[tuple[Key, AttrType]], - leaves: list[LeafType], + leaves: list[tp.Any], paths: list[PathParts] | None, - return_variables: bool, ) -> None: is_pytree_node_ = type(node_impl) is PytreeNodeImpl @@ -756,13 +740,10 @@ def make_mutable_arraydef(value: variablelib.MutableArray): mutable_arraydef, inner_value = make_mutable_arraydef(inner_value) else: mutable_arraydef = None - if return_variables: - leaf = node - leaf.raw_value = inner_value - elif path is None: + if path is None: leaf = inner_value else: - leaf = node.to_state() # type: ignore[assignment] + leaf = node # type: ignore[assignment] leaf.raw_value = inner_value variabledef = VariableDef( @@ -781,7 +762,7 @@ def make_mutable_arraydef(value: variablelib.MutableArray): nodes.append(variabledef) return elif is_mutable_array: - mutable_arraydef, leaf = make_mutable_arraydef(node) # type: ignore[arg-type] + mutable_arraydef, leaf = make_mutable_arraydef(node) # type: ignore[arg-type] if not isinstance(leaf, Repeated): leaves.append(leaf) if path is not None: @@ -829,7 +810,6 @@ def make_mutable_arraydef(value: variablelib.MutableArray): attributes, leaves, paths, - return_variables, ) elif variablelib.is_mutable_array(value): attributes.append((key, MUTABLE_ARRAY_ATTR)) @@ -1188,7 +1168,7 @@ def get_mutable_array(mutable_arraydef: MutableArrayDef, leaf): ) elif type(leaf) in (NoUpdate, Repeated): raise ValueError( - 'Expected a MutableArrayOutput type but got ' f"'{leaf.value}.'" + f"Expected a MutableArrayOutput type but got '{leaf.value}.'" ) elif type(leaf) is MutableArrayOutput: mutable_array = variablelib.mutable_array(leaf.value) @@ -1216,7 +1196,8 @@ def get_mutable_array(mutable_arraydef: MutableArrayDef, leaf): else: value = next(leaves_iter) assert type(variabledef.mutable_arraydef) is MutableArrayDef - if isinstance(value, Variable | VariableState): + if isinstance(value, Variable): + value = value.copy() inner_value = value.raw_value mutable_array = get_mutable_array( variabledef.mutable_arraydef, inner_value @@ -1228,6 +1209,8 @@ def get_mutable_array(mutable_arraydef: MutableArrayDef, leaf): value = get_mutable_array(variabledef.mutable_arraydef, value) else: value = next(leaves_iter) + if isinstance(value, Variable): + value = value.copy() # when idxmap is present, check if the Varable exists there # and update existing variables if it does @@ -1241,11 +1224,6 @@ def get_mutable_array(mutable_arraydef: MutableArrayDef, leaf): if not isinstance(variable, Variable): raise ValueError(f'Expected a Variable type but got {type(variable)}.') elif isinstance(value, Variable): - raise ValueError( - f'Cannot unflatten flat_state containing Variables when using `outer_index_outer_ref`. ' - f'Got {value!r}' - ) - elif isinstance(value, VariableState): variable.update_from_state(value) else: variable.raw_value = value @@ -1253,8 +1231,6 @@ def get_mutable_array(mutable_arraydef: MutableArrayDef, leaf): # variable reference does not exist outside, create a new one if isinstance(value, Variable): variable = value - elif isinstance(value, VariableState): - variable = value.to_variable() else: variable = variabledef.type.from_metadata( value, dict(variabledef.metadata) @@ -1280,7 +1256,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]: for _ in range(nodedef.num_attributes): key, value = next(attribute_iter) if type(value) is Static: - children.append((key, value.value)) # type: ignore[attribute-error] + children.append((key, value.value)) # type: ignore[attribute-error] elif type(value) is MutableArrayAttr: mutable_arraydef = next(node_iter) assert ( @@ -1414,7 +1390,7 @@ def _graph_pop( id_to_index[id(value)] = len(id_to_index) node_impl.pop_key(node, name) if isinstance(value, Variable): - value = value.to_state() + value = value state[node_path] = value # type: ignore[index] # mypy is wrong here? break else: @@ -1424,8 +1400,8 @@ def _graph_pop( def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]): def _update_variable(node: Variable, value): - if isinstance(value, VariableState): - # updated from VariableState + if isinstance(value, Variable): + # updated from Variable node.update_from_state(value) else: # updated from raw value @@ -1606,7 +1582,7 @@ def create_static_cache(x): # TODO(cgarciae): support Array attribute updates for graph nodes if is_graph_node(x) or isinstance(x, Variable): graphdef, flat_state = flatten( - x, with_paths=True, return_variables=True, ref_index=original_ref_index + x, with_paths=True, ref_index=original_ref_index ) paths = flat_state.paths variables = flat_state.leaves @@ -1696,7 +1672,7 @@ def flatten( # type: ignore[invalid-annotation] self, graph_node: A, /, - ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ... + ) -> tuple[GraphDef[A], FlatState[tp.Any]]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] @@ -1704,7 +1680,7 @@ def flatten( # type: ignore[invalid-annotation] graph_node: A, first: filterlib.Filter, /, - ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ... + ) -> tuple[GraphDef[A], FlatState[tp.Any]]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] @@ -1716,8 +1692,8 @@ def flatten( # type: ignore[invalid-annotation] *filters: filterlib.Filter, ) -> tuple[ GraphDef[A], - FlatState[VariableState[tp.Any]], - tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]], + FlatState[tp.Any], + tpe.Unpack[tuple[FlatState[tp.Any], ...]], ]: ... def flatten( # type: ignore[invalid-annotation] @@ -1727,8 +1703,8 @@ def flatten( # type: ignore[invalid-annotation] with_paths: bool = True, ) -> tuple[ GraphDef[A], - FlatState[VariableState[tp.Any]] | list[tp.Any], - tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]], + FlatState[tp.Any] | list[tp.Any], + tpe.Unpack[tuple[FlatState[tp.Any], ...]], ]: if not with_paths and filters: raise ValueError('Cannot use filters with with_paths=False') @@ -1742,11 +1718,7 @@ def flatten( # type: ignore[invalid-annotation] ref_outer_index = ( ctx.inner_ref_outer_index if ctx and ctx.inner_ref_outer_index else None ) - flat_state: ( - FlatState[VariableState[tp.Any]] - | FlatState[Variable[tp.Any]] - | list[tp.Any] - ) + flat_state: FlatState[tp.Any] | list[tp.Any] leaves: list[tp.Any] if node in self.ref_index: # node is already in the ref_index, call flatten which will return a NodeRef @@ -1772,9 +1744,7 @@ def flatten( # type: ignore[invalid-annotation] if with_paths: paths = node_static_cache.paths - leaves = [ - variable.to_state() for variable in node_static_cache.variables - ] + leaves = node_static_cache.variables else: paths = None leaves = [ @@ -1831,9 +1801,9 @@ class MergeContext: def merge( # type: ignore[invalid-annotation] self, graphdef: GraphDef[A], - state: GraphState | VariableState, + state: GraphState, /, - *states: GraphState | VariableState, + *states: GraphState, ) -> A: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None @@ -1910,7 +1880,7 @@ def unflatten( # type: ignore[invalid-annotation] f'leaves in the state, got {len(leaves)}' ) for variable, leaf in zip(static_cache_node.variables, leaves): - if type(leaf) is VariableState: + if isinstance(leaf, Variable): variable.update_from_state(leaf) else: variable.raw_value = leaf @@ -1945,7 +1915,7 @@ def unflatten( # type: ignore[invalid-annotation] @tp.overload @contextlib.contextmanager -def merge_context() -> tp.Generator[MergeContext, None, None]: ... # type: ignore[bad-return-type] +def merge_context() -> tp.Generator[MergeContext, None, None]: ... # type: ignore[bad-return-type] @tp.overload @contextlib.contextmanager def merge_context( @@ -2190,11 +2160,11 @@ def _split_state( @tp.overload def split( # type: ignore[invalid-annotation] graph_node: A, / -) -> tuple[GraphDef[A], GraphState | VariableState]: ... +) -> tuple[GraphDef[A], GraphState]: ... @tp.overload def split( # type: ignore[invalid-annotation] graph_node: A, first: filterlib.Filter, / -) -> tuple[GraphDef[A], GraphState | VariableState]: ... +) -> tuple[GraphDef[A], GraphState]: ... @tp.overload def split( # type: ignore[invalid-annotation] graph_node: A, @@ -2204,15 +2174,15 @@ def split( # type: ignore[invalid-annotation] *filters: filterlib.Filter, ) -> tuple[ GraphDef[A], - GraphState | VariableState, - tpe.Unpack[tuple[GraphState | VariableState, ...]], + GraphState, + tpe.Unpack[tuple[GraphState, ...]], ]: ... def split( # type: ignore[invalid-annotation] node: A, *filters: filterlib.Filter ) -> tuple[ GraphDef[A], - GraphState | VariableState, - tpe.Unpack[tuple[GraphState | VariableState, ...]], + GraphState, + tpe.Unpack[tuple[GraphState, ...]], ]: """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef @@ -2236,22 +2206,18 @@ def split( # type: ignore[invalid-annotation] >>> jax.tree.map(jnp.shape, params) State({ 'batch_norm': { - 'bias': VariableState( - type=Param, + 'bias': Param( value=(2,) ), - 'scale': VariableState( - type=Param, + 'scale': Param( value=(2,) ) }, 'linear': { - 'bias': VariableState( - type=Param, + 'bias': Param( value=(3,) ), - 'kernel': VariableState( - type=Param, + 'kernel': Param( value=(2, 3) ) } @@ -2259,12 +2225,10 @@ def split( # type: ignore[invalid-annotation] >>> jax.tree.map(jnp.shape, batch_stats) State({ 'batch_norm': { - 'mean': VariableState( - type=BatchStat, + 'mean': BatchStat( value=(2,) ), - 'var': VariableState( - type=BatchStat, + 'var': BatchStat( value=(2,) ) } @@ -2291,7 +2255,10 @@ def split( # type: ignore[invalid-annotation] def _to_nested_state( graphdef: GraphDef[A], flat_states: tp.Iterable[tp.Any] ) -> tuple[tp.Any, ...]: - if not graphdef.nodes or type(graphdef.nodes[0]) in (VariableDef, MutableArrayDef): + if not graphdef.nodes or type(graphdef.nodes[0]) in ( + VariableDef, + MutableArrayDef, + ): states = tuple( flat_state[0][1] if flat_state else State({}) for flat_state in flat_states @@ -2369,9 +2336,7 @@ def merge( # type: ignore[invalid-annotation] """ if isinstance(state, list): if len(states) != 0: - raise ValueError( - f'Only one state can be passed as a list.' - ) + raise ValueError(f'Only one state can be passed as a list.') _state = state else: _state = _merge_to_flat_state((state, *states)) @@ -2429,59 +2394,6 @@ def _variables_generator(node) -> tp.Iterable[tuple[PathParts, Variable]]: yield path, value -@tp.overload -def variables(node, /) -> State[Key, Variable]: ... -@tp.overload -def variables(node, first: filterlib.Filter, /) -> State[Key, Variable]: ... -@tp.overload -def variables( - node, - first: filterlib.Filter, - second: filterlib.Filter, - /, - *filters: filterlib.Filter, -) -> tuple[State[Key, Variable], ...]: ... -def variables( - node, - *filters: filterlib.Filter, -) -> tp.Union[State[Key, Variable], tuple[State[Key, Variable], ...]]: - """Similar to :func:`state` but returns the current :class:`Variable` objects instead - of new :class:`VariableState` instances. - - Example:: - - >>> from flax import nnx - ... - >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) - >>> params = nnx.variables(model, nnx.Param) - ... - >>> assert params['kernel'] is model.kernel - >>> assert params['bias'] is model.bias - - Args: - node: A graph node object. - *filters: One or more :class:`Variable` objects to filter by. - Returns: - One or more :class:`State` mappings containing the :class:`Variable` objects. - """ - num_filters = len(filters) - if num_filters == 0: - filters = (..., ...) - else: - filters = (*filters, ...) - - variables_iterable = _variables_generator(node) - flat_states = variablelib.split_flat_state( - variables_iterable, (*filters, ...) - ) - states = tuple( - statelib.from_flat_state(flat_state) for flat_state in flat_states - ) - if num_filters < 2: - return states[0] - return states - - @tp.overload def state(node, /) -> GraphState: ... @tp.overload @@ -2532,7 +2444,7 @@ def state( states: GraphState | tuple[GraphState, ...] if len(filters) == 0: - states = state # type: ignore[assignment] + states = state # type: ignore[assignment] elif len(filters) == 1: states = statelib.filter_state(state, filters[0]) else: @@ -2541,6 +2453,9 @@ def state( return states +variables = state + + def graphdef(node: tp.Any, /) -> GraphDef[tp.Any]: """Get the :class:`GraphDef` of the given graph node. @@ -2662,31 +2577,39 @@ def clone(node: Node) -> Node: graphdef, state = split(node) return merge(graphdef, state) -def find_duplicates(tree) -> tuple[str, str] | None: + +def find_duplicates( + tree, duplicate_fn: tp.Callable[[tuple[Key, ...], tp.Any], bool] | None = None +) -> tuple[str, str] | None: mutable_arrays: dict[int, str] = {} - paths_leaves = jax.tree.leaves_with_path(tree) + paths_leaves = jax.tree.leaves_with_path( + tree, is_leaf=lambda x: isinstance(x, Variable) + ) for path, x in paths_leaves: - m_array_id = id(x) - if m_array_id in mutable_arrays: - current_path_str = jax.tree_util.keystr(path) - previous_path_str = mutable_arrays[m_array_id] - return current_path_str, previous_path_str - mutable_arrays[m_array_id] = jax.tree_util.keystr(path) + nnx_path = jax_to_nnx_path(path) + if duplicate_fn is None or duplicate_fn(nnx_path, x): + m_array_id = id(x) + if m_array_id in mutable_arrays: + current_path_str = jax.tree_util.keystr(path) + previous_path_str = mutable_arrays[m_array_id] + return current_path_str, previous_path_str + mutable_arrays[m_array_id] = jax.tree_util.keystr(path) return None + def _mutable_like(path, x): return ( - isinstance(x, Variable | VariableState) and x.mutable + isinstance(x, Variable) and x.mutable ) or variablelib.is_mutable_array(x) def freeze( - node: A, - /, - *, - only: filterlib.Filter = _mutable_like, - allow_duplicates: bool = False, + node: A, + /, + *, + only: filterlib.Filter = _mutable_like, + allow_duplicates: bool = False, ) -> A: """Converts a structure of mutable arrays to regular arrays. @@ -2726,11 +2649,15 @@ def freeze( Returns: A structure with the frozen arrays. """ - if not allow_duplicates and (duplicate := find_duplicates(node)) is not None: + duplicate_fn = filterlib.to_predicate(only) + if ( + not allow_duplicates + and (duplicate := find_duplicates(node, duplicate_fn=duplicate_fn)) + is not None + ): current_path_str, previous_path_str = duplicate raise ValueError( - f"Found duplicate at path '{current_path_str}' " - f"and '{previous_path_str}'." + f"Found duplicate at path '{current_path_str}' and '{previous_path_str}'." ) graphdef, mutable_state, rest = split(node, only, ...) # type: ignore[misc] frozen_state = jax.tree.map(lambda x: x[...], mutable_state) @@ -2739,9 +2666,7 @@ def freeze( def _array_like(path, x): - return ( - isinstance(x, Variable | VariableState) and not x.mutable - ) or isinstance(x, jax.Array) + return (isinstance(x, Variable) and not x.mutable) or isinstance(x, jax.Array) def mutable(node: A, /, only: filterlib.Filter = _array_like) -> A: @@ -2783,11 +2708,11 @@ def mutable(node: A, /, only: filterlib.Filter = _array_like) -> A: Returns: A structure with the mutable arrays. """ - if (duplicate := find_duplicates(node)) is not None: + duplicate_fn = filterlib.to_predicate(only) + if (duplicate := find_duplicates(node, duplicate_fn=duplicate_fn)) is not None: current_path_str, previous_path_str = duplicate raise ValueError( - f"Found duplicate at path '{current_path_str}' " - f"and '{previous_path_str}'." + f"Found duplicate at path '{current_path_str}' and '{previous_path_str}'." ) graphdef, frozen_state, rest = split(node, only, ...) # type: ignore[misc] mutable_state = jax.tree.map(variablelib.mutable_array, frozen_state) @@ -2796,7 +2721,7 @@ def mutable(node: A, /, only: filterlib.Filter = _array_like) -> A: def pure(tree: A) -> A: - """Returns a new tree with all ``Variable`` and ``VariableState`` objects replaced with inner values. + """Returns a new tree with all ``Variable`` objects replaced with inner values. This can be used to remove Variable metadata when its is not needed for tasks like serialization or exporting. @@ -2811,12 +2736,10 @@ def pure(tree: A) -> A: >>> graphdef, state = nnx.split(model) >>> jax.tree.map(jnp.shape, state) State({ - 'bias': VariableState( - type=Param, + 'bias': Param( value=(3,) ), - 'kernel': VariableState( - type=Param, + 'kernel': Param( value=(2, 3) ) }) @@ -2828,20 +2751,21 @@ def pure(tree: A) -> A: }) Args: - tree: A pytree potentially containing ``Variable`` and ``VariableState`` objects. + tree: A pytree potentially containing ``Variable`` objects. Returns: - A new pytree with all ``Variable`` and ``VariableState`` objects replaced with their + A new pytree with all ``Variable`` objects replaced with their inner values. """ + def _pure_fn(x): - if isinstance(x, Variable | VariableState): + if isinstance(x, Variable): return x.raw_value return x return jax.tree.map( _pure_fn, tree, - is_leaf=lambda x: isinstance(x, Variable | VariableState), + is_leaf=lambda x: isinstance(x, Variable), ) @@ -3047,6 +2971,7 @@ class IndexesPytreeDef(tp.NamedTuple): key_index: HashableMapping[Key, int] treedef: jax.tree_util.PyTreeDef + def _flatten_pytree(pytree: tp.Any): leaves, treedef = jax.tree_util.tree_flatten_with_path( pytree, is_leaf=lambda x: x is not pytree diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py index 73734fd65..491689a1d 100644 --- a/flax/nnx/nn/linear.py +++ b/flax/nnx/nn/linear.py @@ -301,12 +301,10 @@ class Linear(Module): >>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ - 'bias': VariableState( - type=Param, + 'bias': Param( value=(4,) ), - 'kernel': VariableState( - type=Param, + 'kernel': Param( value=(3, 4) ) }) @@ -1106,8 +1104,7 @@ class Embed(Module): >>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ - 'embedding': VariableState( # 15 (60 B) - type=Param, + 'embedding': Param( # 15 (60 B) value=Array([[ 0.57966787, -0.523274 , -0.43195742], [-0.676289 , -0.50300646, 0.33996582], [ 0.41796115, -0.59212935, 0.95934135], diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index 837dff35a..a536a8542 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -202,20 +202,16 @@ class BatchNorm(Module): ... dtype=jnp.float32, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ - 'bias': VariableState( - type=Param, + 'bias': Param( value=(6,) ), - 'mean': VariableState( - type=BatchStat, + 'mean': BatchStat( value=(6,) ), - 'scale': VariableState( - type=Param, + 'scale': Param( value=(6,) ), - 'var': VariableState( - type=BatchStat, + 'var': BatchStat( value=(6,) ) }) @@ -223,7 +219,7 @@ class BatchNorm(Module): >>> # calculate batch norm on input and update batch statistics >>> layer.train() >>> y = layer(x) - >>> batch_stats1 = nnx.state(layer, nnx.BatchStat) + >>> batch_stats1 = nnx.clone(nnx.state(layer, nnx.BatchStat)) # keep a copy >>> y = layer(x) >>> batch_stats2 = nnx.state(layer, nnx.BatchStat) >>> assert (batch_stats1['mean'].value != batch_stats2['mean'].value).all() @@ -402,12 +398,10 @@ class LayerNorm(Module): >>> nnx.state(layer) State({ - 'bias': VariableState( # 6 (24 B) - type=Param, + 'bias': Param( # 6 (24 B) value=Array([0., 0., 0., 0., 0., 0.], dtype=float32) ), - 'scale': VariableState( # 6 (24 B) - type=Param, + 'scale': Param( # 6 (24 B) value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) ) }) @@ -540,8 +534,7 @@ class RMSNorm(Module): >>> nnx.state(layer) State({ - 'scale': VariableState( # 6 (24 B) - type=Param, + 'scale': Param( # 6 (24 B) value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) ) }) @@ -666,12 +659,10 @@ class GroupNorm(Module): >>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ - 'bias': VariableState( # 6 (24 B) - type=Param, + 'bias': Param( # 6 (24 B) value=Array([0., 0., 0., 0., 0., 0.], dtype=float32) ), - 'scale': VariableState( # 6 (24 B) - type=Param, + 'scale': Param( # 6 (24 B) value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) ) }) diff --git a/flax/nnx/object.py b/flax/nnx/object.py index 1e4fcd2e7..3184bfdea 100644 --- a/flax/nnx/object.py +++ b/flax/nnx/object.py @@ -36,7 +36,7 @@ visualization, ) from flax import config -from flax.nnx.variablelib import Variable, VariableState, is_mutable_array +from flax.nnx.variablelib import Variable, is_mutable_array from flax.typing import SizeBytes BUILDING_DOCS = 'FLAX_DOC_BUILD' in os.environ @@ -598,7 +598,7 @@ def _graph_node_set_key(self, key: str, value: tp.Any): elif ( hasattr(self, key) and isinstance(variable := getattr(self, key), Variable) - and isinstance(value, VariableState) + and isinstance(value, Variable) ): variable.update_from_state(value) else: diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index dc4d2d036..175abcc88 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -31,6 +31,7 @@ F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) PARTITION_NAME = 'partition_name' + class HasSharding(tp.Protocol): sharding: tuple[str | None, ...] | None @@ -38,6 +39,7 @@ class HasSharding(tp.Protocol): def _has_sharding(x: tp.Any) -> tp.TypeGuard[HasSharding]: return hasattr(x, 'sharding') and x.sharding is not None + def add_axis(tree: A, index: int, transform_metadata: tp.Mapping) -> A: axis_name, other_meta = _get_partition_name_and_metadata(transform_metadata) @@ -49,24 +51,28 @@ def insert_field(fields, index, value): return tuple(iterable) def _add_axis(x: tp.Any): - if isinstance(x, variablelib.VariableState): - if _has_sharding(x) and x.sharding is not None: - x.sharding = insert_field(x.sharding, index, axis_name) + if isinstance(x, variablelib.Variable): + metadata = x.get_metadata() + if 'sharding' in metadata and metadata['sharding']: + sharding = metadata['sharding'] + x.sharding = insert_field(sharding, index, axis_name) for k, v in other_meta.items(): if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple): setattr(x, k, insert_field(t, index, v)) - assert isinstance(x, variablelib.VariableState) + assert isinstance(x, variablelib.Variable) x.add_axis(index, axis_name) return x return jax.tree.map( - _add_axis, tree, is_leaf=lambda x: isinstance(x, variablelib.VariableState) + _add_axis, tree, is_leaf=lambda x: isinstance(x, variablelib.Variable) ) -def remove_axis(tree: A, index: int, transform_metadata: tp.Mapping[tp.Any, tp.Any]) -> A: +def remove_axis( + tree: A, index: int, transform_metadata: tp.Mapping[tp.Any, tp.Any] +) -> A: axis_name, other_meta = _get_partition_name_and_metadata(transform_metadata) def remove_field(fields, index, value): @@ -75,7 +81,7 @@ def remove_field(fields, index, value): return tuple(iterable) def _remove_axis(x: tp.Any): - if isinstance(x, variablelib.VariableState): + if isinstance(x, variablelib.Variable): if hasattr(x, 'sharding') and x.sharding is not None: x.sharding = remove_field(x.sharding, index, axis_name) @@ -89,12 +95,12 @@ def _remove_axis(x: tp.Any): return jax.tree.map( _remove_axis, tree, - is_leaf=lambda x: isinstance(x, variablelib.VariableState), + is_leaf=lambda x: isinstance(x, variablelib.Variable), ) def _get_partition_name_and_metadata( - transform_metadata: tp.Mapping[tp.Any, tp.Any] + transform_metadata: tp.Mapping[tp.Any, tp.Any], ) -> tuple[str, tp.Mapping[tp.Any, tp.Any]]: if PARTITION_NAME not in transform_metadata: raise ValueError( @@ -116,35 +122,31 @@ def _maybe_replicate(x): return None def f(x): - if isinstance(x, (variablelib.VariableState, variablelib.Variable)): - if hasattr(x, 'sharding') and x.sharding: - if core_spmd.get_logical_axis_rules() or hasattr(x, 'sharding_rules'): + if isinstance(x, variablelib.Variable): + metadata = x.get_metadata() + if 'sharding' in metadata and metadata['sharding']: + sharding = metadata['sharding'] + if core_spmd.get_logical_axis_rules() or 'sharding_rules' in metadata: context_rules = core_spmd.get_logical_axis_rules() - local_rules = getattr(x, 'sharding_rules', ()) + local_rules = metadata.get('sharding_rules', ()) rules = core_spmd.composite_rules(context_rules, local_rules) return x.replace( - PartitionSpec(*core_spmd.from_sharding_rules(x.sharding, rules)) + PartitionSpec(*core_spmd.from_sharding_rules(sharding, rules)) ) - return x.replace(PartitionSpec(*x.sharding)) + return x.replace(PartitionSpec(*sharding)) else: return x.replace(_maybe_replicate(x.raw_value)) return _maybe_replicate(x) return jax.tree.map( - f, - tree, - is_leaf=lambda x: isinstance( - x, (variablelib.VariableState, variablelib.Variable) - ), + f, tree, is_leaf=lambda x: isinstance(x, variablelib.Variable) ) def get_named_sharding(tree: A, mesh: jax.sharding.Mesh) -> A: spec = get_partition_spec(tree) - sharding = jax.tree.map( - lambda p: jax.sharding.NamedSharding(mesh, p), spec - ) + sharding = jax.tree.map(lambda p: jax.sharding.NamedSharding(mesh, p), spec) return sharding @@ -174,7 +176,7 @@ def _with_sharding_constraint( def _is_spec(x): - return x is None or ( + return x is None or isinstance(x, variablelib.Variable) or ( isinstance(x, tuple) and all(isinstance(e, str) or e is None for e in x) ) diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 4128e5fb1..7d833e193 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -219,9 +219,7 @@ def _flat_state_pytree_unflatten( class State(MutableMapping[K, V], reprlib.Representable): - """A pytree-like structure that contains a ``Mapping`` from hashable and - comparable keys to leaves. Leaves can be of any type but :class:`VariableState` - and :class:`Variable` are the most common. + """A pytree-like ``Mapping`` with hashable and comparable keys. """ def __init__( @@ -491,7 +489,7 @@ def from_flat_state( def to_pure_dict( state, extract_fn: ExtractValueFn | None = None ) -> dict[str, tp.Any]: - # Works for nnx.Variable and nnx.VariableState + # Works for nnx.Variable if extract_fn is None: extract_fn = lambda x: x.value if hasattr(x, 'value') else x flat_values = {k: extract_fn(x) for k, x in to_flat_state(state)} @@ -507,7 +505,7 @@ def try_convert_int(x): except ValueError: return x - # Works for nnx.Variable and nnx.VariableState + # Works for nnx.Variable if replace_fn is None: replace_fn = lambda x, v: x.replace(v) if hasattr(x, 'replace') else v current_flat = dict(to_flat_state(state)) @@ -749,7 +747,7 @@ def create_path_filters(state: State): flat_state = to_flat_state(state) value_paths: dict[tp.Any, set[PathParts]] = {} for path, value in flat_state: - if isinstance(value, (variablelib.Variable, variablelib.VariableState)): + if isinstance(value, variablelib.Variable): value = value.raw_value value_paths.setdefault(value, set()).add(path) return {filterlib.PathIn(*value_paths[value]): value for value in value_paths} \ No newline at end of file diff --git a/flax/nnx/summary.py b/flax/nnx/summary.py index f8af484c3..c3e10033e 100644 --- a/flax/nnx/summary.py +++ b/flax/nnx/summary.py @@ -334,9 +334,9 @@ def tabulate( _collect_stats((), obj, node_stats, object_types) _variable_types: set[type] = { nnx.RngState # type: ignore[misc] - if issubclass(variable_state.type, nnx.RngState) - else variable_state.type - for _, variable_state in nnx.to_flat_state(nnx.state(obj)) + if isinstance(leaf, nnx.RngState) + else type(leaf) + for _, leaf in nnx.to_flat_state(nnx.state(obj)) } variable_types: list[type] = sorted(_variable_types, key=lambda t: t.__name__) diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index ef423b1b8..98634b526 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -22,7 +22,7 @@ from flax import nnx from flax.nnx import filterlib from flax.nnx.object import Object -from flax.nnx.variablelib import Variable, VariableState +from flax.nnx.variablelib import Variable M = tp.TypeVar('M', bound=nnx.Module) @@ -44,73 +44,23 @@ class OptArray(OptState): class OptVariable(OptState): """Optimizer state for a Variable.""" - source_type: type[Variable] pass -def _wrap_optimizer_state(opt_state): - def wrap_optimizer_state_fn(x): - if isinstance(x, VariableState): - new_state = x.copy() - new_state.source_type = x.type - new_state.type = OptVariable - return new_state.to_variable() - else: - return OptArray(x) - - return jax.tree.map( - wrap_optimizer_state_fn, - opt_state, - is_leaf=lambda x: isinstance(x, VariableState), - ) - - -def _opt_state_variables_to_state(opt_state): - def optimizer_variable_to_state_fn(x): - if isinstance(x, OptVariable): - state = x.to_state() - state.type = x.source_type - del state.source_type - return state - elif isinstance(x, OptArray): - return x.value +def to_opt_state(tree): + def _to_opt_state(x): + if isinstance(x, Variable): + opt_state = OptVariable(x[...], **x.get_metadata()) # type: ignore else: - raise TypeError( - f'Unexpected type when converting optimizer state: {type(x)}' - ) - - return jax.tree.map( - optimizer_variable_to_state_fn, - opt_state, - is_leaf=lambda x: isinstance(x, nnx.Variable), - ) - + opt_state = OptArray(x) + return opt_state -def _update_opt_state(opt_state, updates): - def optimizer_update_variables(x, update): - if isinstance(x, OptVariable): - if not isinstance(update, VariableState): - raise TypeError( - f'Expected update to be VariableState, got {type(update)}' - ) - x.value = update.value - elif isinstance(x, OptArray): - if isinstance(update, VariableState): - raise TypeError( - f'Expected update to not to be a VariableState, got {update}' - ) - x.value = update - else: - raise TypeError( - f'Unexpected type when updating optimizer state: {type(x)}' - ) - - return jax.tree.map( - optimizer_update_variables, - opt_state, - updates, - is_leaf=lambda x: isinstance(x, nnx.Variable), + tree = jax.tree.map( + _to_opt_state, + tree, + is_leaf=lambda x: isinstance(x, Variable), ) + return tree class Optimizer(Object, tp.Generic[M]): @@ -203,7 +153,7 @@ def __init__( self.model = model self.tx = tx self.opt_state = nnx.data( - _wrap_optimizer_state(tx.init(nnx.state(model, wrt))) + to_opt_state(tx.init(nnx.state(model, wrt))) ) self.wrt = wrt @@ -229,17 +179,14 @@ def update(self, grads, **kwargs): >>> model = Model(rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(model)) State({ - 'custom_variable': VariableState( - type=CustomVariable, + 'custom_variable': CustomVariable( value=(1, 3) ), 'linear': { - 'bias': VariableState( - type=Param, + 'bias': Param( value=(3,) ), - 'kernel': VariableState( - type=Param, + 'kernel': Param( value=(2, 3) ) } @@ -267,31 +214,20 @@ def update(self, grads, **kwargs): ``GradientTransformationExtraArgs``, such as ``optax.scale_by_backtracking_linesearch``. """ params = nnx.state(self.model, self.wrt) - opt_state = _opt_state_variables_to_state(self.opt_state) + param_arrays = nnx.freeze(nnx.pure(params)) + grad_arrays = nnx.freeze(nnx.pure(grads)) + opt_state_arrays = nnx.freeze(nnx.pure(self.opt_state)) + kwargs_arrays = nnx.freeze(nnx.pure(kwargs)) - updates, new_opt_state = self.tx.update(grads, opt_state, params, **kwargs) - new_params = optax.apply_updates(params, updates) - assert isinstance(new_params, nnx.State) + updates, new_opt_state = self.tx.update( + grad_arrays, opt_state_arrays, param_arrays, **kwargs_arrays + ) + new_params = optax.apply_updates(param_arrays, updates) - self.step.value += 1 nnx.update(self.model, new_params) - _update_opt_state(self.opt_state, new_opt_state) - - -def to_opt_state(tree): - def _to_opt_state(x): - if isinstance(x, Variable | VariableState): - opt_state = OptVariable(x[...], **x.get_metadata()) # type: ignore - else: - opt_state = OptArray(x) - return opt_state + nnx.update(self.opt_state, nnx.state(new_opt_state)) + self.step[...] += 1 - tree = jax.tree.map( - _to_opt_state, - tree, - is_leaf=lambda x: isinstance(x, Variable | VariableState), - ) - return tree class PytreeOptimizer(Object): @@ -381,6 +317,6 @@ def _update_variable(param, value): _update_variable, (params, self.opt_state), (new_params, new_opt_state), - is_leaf=lambda x: isinstance(x, Variable | VariableState), + is_leaf=lambda x: isinstance(x, Variable), ) self.step[...] += 1 \ No newline at end of file diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 0e67125b6..a2ca07c1d 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -64,7 +64,7 @@ class DiffState: class GradFn: f: tp.Callable[..., tp.Any] has_aux: bool - nondiff_states: deque[State | variablelib.VariableState | None] + nondiff_states: deque[State | None] def __post_init__(self): functools.update_wrapper(self, self.f) @@ -134,7 +134,7 @@ def _grad_general( def grad_wrapper(*args, **kwargs): args = resolve_kwargs(f, args, kwargs) del kwargs - nondiff_states: deque[State | variablelib.VariableState | None] = deque() + nondiff_states: deque[State | variablelib.Variable | None] = deque() def _grad_split_fn( ctx: graph.SplitContext, path, prefix: DiffState | None, value @@ -249,12 +249,10 @@ def grad( >>> grads = grad_fn(m, x, y) >>> jax.tree.map(jnp.shape, grads) State({ - 'bias': VariableState( - type=Param, + 'bias': Param( value=(3,) ), - 'kernel': VariableState( - type=Param, + 'kernel': Param( value=(2, 3) ) }) @@ -275,8 +273,7 @@ def grad( >>> grads = grad_fn(m, x, y) >>> jax.tree.map(jnp.shape, grads) State({ - 'kernel': VariableState( - type=Param, + 'kernel': Param( value=(2, 3) ) }) @@ -568,8 +565,8 @@ def state_to_node_states(is_differentiable: bool, x): if is_differentiable: if isinstance(x, jax.Array): return x - elif not isinstance(x, State | variablelib.VariableState): - raise ValueError(f'Expected State or VariableState, got {type(x)}') + elif not isinstance(x, State | variablelib.Variable): + raise ValueError(f'Expected State or Variable, got {type(x)}') return extract.NodeStates.from_states(x) return x @@ -577,7 +574,7 @@ def state_to_node_states(is_differentiable: bool, x): state_to_node_states, self.tree_node_args, tangent, - is_leaf=lambda x: isinstance(x, State | variablelib.VariableState), + is_leaf=lambda x: isinstance(x, State | variablelib.Variable), ) return pure_tangent @@ -765,12 +762,10 @@ def custom_vjp( ... >>> jax.tree.map(jnp.shape, grads) State({ - 'x': VariableState( - type=Param, + 'x': Param( value=() ), - 'y': VariableState( - type=Param, + 'y': Param( value=() ) }) @@ -813,8 +808,7 @@ def custom_vjp( ... >>> jax.tree.map(jnp.shape, grad) State({ - 'x': VariableState( - type=Param, + 'x': Param( value=() ) }) diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index a2e6eed26..d954ca977 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -113,8 +113,8 @@ def __hash__(self): AxisFn = tp.Callable[ - [graph.GraphState | variablelib.VariableState, int, tp.Mapping], - graph.GraphState | variablelib.VariableState, + [graph.GraphState | variablelib.Variable, int, tp.Mapping], + graph.GraphState | variablelib.Variable, ] @@ -127,13 +127,13 @@ def _update_axes_fn(node_states): ): if isinstance(node_states.metadata, int): state = node_states.state - assert isinstance(state, State | variablelib.VariableState) + assert isinstance(state, State | variablelib.Variable) state = axis_fn(state, node_states.metadata, transform_metadata) return node_states.replace(states=(state,)) else: - states_out: list[graph.GraphState | variablelib.VariableState] = [] + states_out: list[graph.GraphState | variablelib.Variable] = [] for state, axis in zip(node_states.states, node_states.metadata.axes): - assert isinstance(state, graph.State | variablelib.VariableState) + assert isinstance(state, graph.State | variablelib.Variable) if isinstance(axis, int): state = axis_fn(state, axis, transform_metadata) states_out.append(state) @@ -697,8 +697,8 @@ def insert_index_mappings(x): def _scan_split_in( - carry_deque: PytreeDeque[list[State | variablelib.VariableState]], - broadcast_deque: PytreeDeque[list[State | variablelib.VariableState]], + carry_deque: PytreeDeque[list[State | variablelib.Variable]], + broadcast_deque: PytreeDeque[list[State | variablelib.Variable]], broadcast_arrays: PytreeDeque[Broadcasted], /, ctx: graph.SplitContext, @@ -707,9 +707,9 @@ def _scan_split_in( x, ): if graph.is_graph_node(x) or isinstance(x, variablelib.Variable): - vectorized_states: list[State | variablelib.VariableState] = [] - carry_states: list[State | variablelib.VariableState] = [] - broadcast_states: list[State | variablelib.VariableState] = [] + vectorized_states: list[State | variablelib.Variable] = [] + carry_states: list[State | variablelib.Variable] = [] + broadcast_states: list[State | variablelib.Variable] = [] if isinstance(prefix, StateAxes): graphdef, *states = ctx.split(x, *prefix.filters) @@ -778,8 +778,8 @@ def _scan_split_in( def _scan_split_out( - carry_deque: PytreeDeque[list[State | variablelib.VariableState]], - broadcast_deque: PytreeDeque[list[State | variablelib.VariableState]], + carry_deque: PytreeDeque[list[State | variablelib.Variable]], + broadcast_deque: PytreeDeque[list[State | variablelib.Variable]], /, ctx: graph.SplitContext, path: extract.KeyPath, @@ -790,9 +790,9 @@ def _scan_split_out( is_input_arg = path[0].idx == 0 if graph.is_graph_node(x) or isinstance(x, variablelib.Variable): - vectorized_states: list[State | variablelib.VariableState] = [] - carry_states: list[State | variablelib.VariableState] = [] - broadcast_states: list[State | variablelib.VariableState] = [] + vectorized_states: list[State | variablelib.Variable] = [] + carry_states: list[State | variablelib.Variable] = [] + broadcast_states: list[State | variablelib.Variable] = [] if isinstance(prefix, StateAxes): graphdef, *states = ctx.split(x, *prefix.filters) @@ -868,8 +868,8 @@ def _scan_split_out( def _scan_merge_in( - carry_deque: PytreeDeque[list[State | variablelib.VariableState]], - broadcast_deque: PytreeDeque[list[State | variablelib.VariableState]], + carry_deque: PytreeDeque[list[State]], + broadcast_deque: PytreeDeque[list[State]], broadcast_arrays: PytreeDeque[Broadcasted], /, ctx: graph.MergeContext, @@ -889,8 +889,8 @@ def _scan_merge_in( def _scan_merge_out( - carry_deque: PytreeDeque[list[State | variablelib.VariableState]], - broadcast_deque: PytreeDeque[list[State | variablelib.VariableState]], + carry_deque: PytreeDeque[list[State]], + broadcast_deque: PytreeDeque[list[State]], /, ctx: graph.MergeContext, path, @@ -901,13 +901,13 @@ def _scan_merge_out( is_input_arg = path[0].idx == 0 if isinstance(x, extract.NodeStates): - states: list[State | variablelib.VariableState] = [] + states: list[State] = [] if is_input_arg: carry_states = deque(carry_deque.popleft()) broadcast_states = deque(broadcast_deque.popleft()) else: - carry_states = deque[State | variablelib.VariableState]() - broadcast_states = deque[State | variablelib.VariableState]() + carry_states = deque[State]() + broadcast_states = deque[State]() if isinstance(prefix, StateAxes): vectorized_states = deque(x.states) for axis in prefix.axes: @@ -981,8 +981,8 @@ def __call__( self, carry: tuple[ tp.Any, # carry_arg - PytreeDeque[list[State | variablelib.VariableState]], # carry_deque - PytreeDeque[list[State | variablelib.VariableState]], # broadcast_deque + PytreeDeque[list[State]], # carry_deque + PytreeDeque[list[State]], # broadcast_deque PytreeDeque[Broadcasted], # broadcast_arrays ], pure_args: tuple[tp.Any, ...], @@ -1063,9 +1063,9 @@ def __call__( assert self.input_carry_argnum is None assert carry_arg_out is None - carry_deque_out = PytreeDeque[list[State | variablelib.VariableState]]() + carry_deque_out = PytreeDeque[list[State | variablelib.Variable]]() _broadcast_deque_out_tmp = PytreeDeque[ - list[State | variablelib.VariableState] + list[State | variablelib.Variable] ]() # discarded pure_args_out: tuple pure_args_out, pure_out = extract.to_tree( diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 67b809a78..f7cc958ec 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -90,12 +90,10 @@ class Variable(tp.Generic[A], reprlib.Representable): >>> jax.tree.map(jnp.shape, linear_variables) State({ 'linear': { - 'bias': VariableState( - type=Param, + 'bias': Param( value=(3,) ), - 'kernel': VariableState( - type=Param, + 'kernel': Param( value=(2, 3) ) } @@ -104,8 +102,7 @@ class Variable(tp.Generic[A], reprlib.Representable): >>> custom_variable = nnx.state(model, CustomVariable) >>> jax.tree.map(jnp.shape, custom_variable) State({ - 'custom_variable': VariableState( - type=CustomVariable, + 'custom_variable': CustomVariable( value=(1, 3) ) }) @@ -113,17 +110,14 @@ class Variable(tp.Generic[A], reprlib.Representable): >>> variables = nnx.state(model) >>> jax.tree.map(jnp.shape, variables) State({ - 'custom_variable': VariableState( - type=CustomVariable, + 'custom_variable': CustomVariable( value=(1, 3) ), 'linear': { - 'bias': VariableState( - type=Param, + 'bias': Param( value=(3,) ), - 'kernel': VariableState( - type=Param, + 'kernel': Param( value=(2, 3) ) } @@ -219,9 +213,11 @@ def __delattr__(self, name: str): else: del self._var_metadata[name] - @classmethod - def state(cls, value: A, **metadata) -> VariableState[A]: - return cls(value, **metadata).to_state() + # NOTE(cgarciae): adding this for backward compatibility with VariableState + @property + def type(self): + """The type of the variable.""" + return type(self) @property def mutable(self) -> bool: @@ -250,7 +246,7 @@ def copy_from(self, other: Variable[A]) -> None: self._var_metadata.clear() self._var_metadata.update(other.get_metadata()) - def update_from_state(self, variable_state: VariableState[A]): + def update_from_state(self, variable_state: Variable[A]): object.__setattr__(self, 'raw_value', variable_state.raw_value) object.__setattr__( self, '_var_metadata', variable_state._var_metadata.copy() @@ -342,13 +338,12 @@ def from_metadata(cls, value: A, attributes: dict[str, tp.Any]): def copy(self: Variable[A]) -> Variable[A]: obj = object.__new__(type(self)) - object.__setattr__(obj, '_trace_state', self._trace_state) + object.__setattr__(obj, '_trace_state', tracers.TraceState()) object.__setattr__(obj, 'raw_value', self.raw_value) object.__setattr__(obj, '_var_metadata', self.get_metadata().copy()) return obj - def to_state(self: Variable[A]) -> VariableState[A]: - return VariableState(type(self), self.raw_value, **self._var_metadata) + to_state = copy def __nnx_repr__(self): stats = SizeBytes.from_any(self.raw_value) @@ -806,12 +801,10 @@ class Param(Variable[A]): >>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ - 'bias': VariableState( - type=Param, + 'bias': Param( value=(3,) ), - 'kernel': VariableState( - type=Param, + 'kernel': Param( value=(2, 3) ) }) @@ -833,20 +826,16 @@ class BatchStat(Variable[A]): >>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ - 'bias': VariableState( - type=Param, + 'bias': Param( value=(3,) ), - 'mean': VariableState( - type=BatchStat, + 'mean': BatchStat( value=(3,) ), - 'scale': VariableState( - type=Param, + 'scale': Param( value=(3,) ), - 'var': VariableState( - type=BatchStat, + 'var': BatchStat( value=(3,) ) }) @@ -873,16 +862,13 @@ class Cache(Variable[A]): >>> layer.init_cache((1, 3)) >>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache)) State({ - 'cache_index': VariableState( - type=Cache, + 'cache_index': Cache( value=() ), - 'cached_key': VariableState( - type=Cache, + 'cached_key': Cache( value=(1, 2, 3) ), - 'cached_value': VariableState( - type=Cache, + 'cached_value': Cache( value=(1, 2, 3) ) }) @@ -913,8 +899,7 @@ class Intermediate(Variable[A]): >>> y = model(x) >>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Intermediate)) State({ - 'i': VariableState( - type=Intermediate, + 'i': Intermediate( value=((1, 3),) ) }) @@ -945,8 +930,7 @@ class Perturbation(Intermediate[A]): >>> y = model(x) >>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Perturbation)) State({ - 'i': VariableState( - type=Perturbation, + 'i': Perturbation( value=(1, 3) ) }) @@ -955,163 +939,6 @@ class Perturbation(Intermediate[A]): pass -class VariableState(tp.Generic[A], reprlib.Representable): - __slots__ = ('type', 'value', '_var_metadata') - type: type[Variable[A]] - value: A - _var_metadata: dict[str, tp.Any] - - def __init__( - self, - type: type[Variable[A]], # type: ignore [valid-type] - value: A, - **metadata, - ): - object.__setattr__(self, 'type', type) - object.__setattr__(self, 'value', value) - object.__setattr__(self, '_var_metadata', metadata) - - @property - def raw_value(self) -> A: - return object.__getattribute__(self, 'value') - - @raw_value.setter - def raw_value(self, value: A) -> None: - object.__setattr__(self, 'value', value) - - @property - def mutable(self) -> bool: - if is_mutable_array(self.raw_value): - return True - elif isinstance(self.raw_value, jax.Array): - return False - else: - raise ValueError( - f'mutable is only supported for jax.Array and MutableArray, ' - f'got {type(self.raw_value).__name__}' - ) - - def __getattribute__(self, name: str) -> None: - if name == 'value': - value = object.__getattribute__(self, 'value') - if is_mutable_array(value): - value = value[...] - return value - return object.__getattribute__(self, name) - - def __getattr__(self, name: str) -> None: - var_metadata = object.__getattribute__(self, '_var_metadata') - if name not in var_metadata: - raise AttributeError(f"'VariableState' object has no attribute '{name}'") - return var_metadata[name] - - def __setattr__(self, name: str, value: Any) -> None: - if name in ('type', 'value', '_var_metadata', 'raw_value'): - object.__setattr__(self, name, value) - else: - self._var_metadata[name] = value - - def __delattr__(self, name: str) -> None: - if name in ('type', 'value', '_var_metadata', 'raw_value'): - object.__delattr__(self, name) - else: - del self._var_metadata[name] - - def __getitem__(self, key: Any) -> jax.Array: - return self.raw_value[key] # type: ignore - - def __setitem__(self, key: Any, value: Any) -> None: - self.raw_value[key] = value # type: ignore - - def __nnx_repr__(self): - stats = SizeBytes.from_any(self.raw_value) - if stats: - comment = f' # {stats}' - else: - comment = '' - - yield reprlib.Object(type=type(self), comment=comment) - yield reprlib.Attr('type', self.type) - yield reprlib.Attr('value', self.raw_value) - - for name, value in self._var_metadata.items(): - yield reprlib.Attr(name, value) - - def __treescope_repr__(self, path, subtree_renderer): - size_bytes = SizeBytes.from_any(self.raw_value) - if size_bytes: - stats_repr = f' # {size_bytes}' - first_line_annotation = treescope.rendering_parts.comment_color( - treescope.rendering_parts.text(f'{stats_repr}') - ) - else: - first_line_annotation = None - children = {'type': self.type, 'value': self.value, **self._var_metadata} - return visualization.render_object_constructor( - object_type=type(self), - attributes=children, - path=path, - subtree_renderer=subtree_renderer, - first_line_annotation=first_line_annotation, - ) - - def replace(self, value: B) -> VariableState[B]: - return VariableState(self.type, value, **self.get_metadata()) - - def to_variable(self) -> Variable[A]: - # we use object.__new__ to avoid calling __init__ and bypass the - # __init__ logic which should not be called twice - variable = object.__new__(self.type) - object.__setattr__(variable, '_trace_state', tracers.TraceState()) - object.__setattr__(variable, 'raw_value', self.raw_value) - object.__setattr__(variable, '_var_metadata', self.get_metadata().copy()) - return variable - - def copy(self: VariableState[A]) -> VariableState[A]: - return jax.tree.map(lambda x: x, self) - - def get_metadata(self) -> dict[str, tp.Any]: - return self._var_metadata - - def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): - if 'on_add_axis' in self._var_metadata: - self._var_metadata['on_add_axis'](self, axis_index, axis_name) - - def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): - if 'on_remove_axis' in self._var_metadata: - self._var_metadata['on_remove_axis'](self, axis_index, axis_name) - -GraphVariableState = VariableState[VariableState[tp.Any]] - -def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool): - metadata = tuple(x.get_metadata().items()) - if with_keys: - node = (jtu.GetAttrKey('value'), x.raw_value) - else: - node = x.raw_value - - return (node,), (x.type, metadata) - - -def _variable_state_unflatten( - static: tuple[type[Variable[A]], tuple[tuple[str, tp.Any], ...]], - children: tuple[A], -) -> VariableState[A]: - return VariableState( - type=static[0], - value=children[0], - **dict(static[1]), - ) - - -jtu.register_pytree_with_keys( - VariableState, - partial(_variable_state_flatten, with_keys=True), # type: ignore - _variable_state_unflatten, # type: ignore - flatten_func=partial(_variable_state_flatten, with_keys=False), # type: ignore -) - - def with_metadata( initializer: F, set_value_hooks: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (), @@ -1184,13 +1011,13 @@ def wrapper(*args): def split_flat_state( - flat_state: tp.Iterable[tuple[PathParts, Variable | VariableState]], + flat_state: tp.Iterable[tuple[PathParts, Variable]], filters: tuple[filterlib.Filter, ...], -) -> tuple[list[tuple[PathParts, Variable | VariableState]], ...]: +) -> tuple[list[tuple[PathParts, Variable]], ...]: predicates = filterlib.filters_to_predicates(filters) # we have n + 1 states, where n is the number of predicates # the last state is for values that don't match any predicate - flat_states: tuple[list[tuple[PathParts, Variable | VariableState]], ...] = ( + flat_states: tuple[list[tuple[PathParts, Variable]], ...] = ( tuple([] for _ in predicates) ) diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 802936845..65c2eb9a2 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -168,29 +168,13 @@ def test_unflatten_return_variables(self): g = List([a, 3, a, nnx.Param(4)]) graphdef, state = nnx.graph.flatten( - g, with_paths=False, return_variables=True + g, with_paths=True ) self.assertLen(state, 2) - self.assertIsInstance(state, list) - self.assertIsInstance(state[0], nnx.Param) - self.assertIsInstance(state[1], nnx.Param) - - def test_clone_with_same_variables(self): - a = Dict({'a': 1, 'b': nnx.Param(2)}) - g = List([a, 3, a, nnx.Param(4)]) - - graphdef, state = nnx.graph.flatten( - g, with_paths=False, return_variables=True - ) - - g2 = nnx.graph.unflatten(graphdef, state) - - self.assertIsNot(g, g2) - self.assertIsNot(g[0], g2[0]) - self.assertIsNot(g[2], g2[2]) - self.assertIs(g[0]['b'], g2[0]['b']) - self.assertIs(g[3], g2[3]) + self.assertIsInstance(state, nnx.graph.FlatState) + self.assertIsInstance(state[0][1], nnx.Param) + self.assertIsInstance(state[1][1], nnx.Param) def test_update_dynamic(self): a = {'a': 1, 'b': nnx.Param(2)} @@ -297,7 +281,7 @@ def __call__(self, x): assert y.shape == (2, 10) - def test_state_variables_not_shared_with_graph(self): + def test_state_variables_shared_with_graph(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(1) @@ -306,18 +290,18 @@ def __init__(self): graphdef, state = nnx.split(m) assert isinstance(m.a, nnx.Param) - assert issubclass(state['a'].type, nnx.Param) - assert m.a is not state['a'] + assert isinstance(state['a'], nnx.Param) + assert m.a is state['a'] assert m.a.value == state['a'].value m2 = nnx.merge(graphdef, state) assert isinstance(m2.a, nnx.Param) - assert issubclass(state['a'].type, nnx.Param) + assert isinstance(state['a'], nnx.Param) assert m2.a is not state['a'] assert m2.a.value == state['a'].value - def test_shared_state_variables_not_shared_with_graph(self): + def test_shared_state_variables_shared_with_graph(self): class Foo(nnx.Module): def __init__(self): p = nnx.Param(1) @@ -329,10 +313,10 @@ def __init__(self): assert isinstance(m.a, nnx.Param) assert isinstance(m.b, nnx.Param) - assert issubclass(state['a'].type, nnx.Param) + assert isinstance(state['a'], nnx.Param) assert 'b' not in state - assert m.a is not state['a'] - assert m.b is not state['a'] + assert m.a is state['a'] + assert m.b is state['a'] assert m.a.value == state['a'].value assert m.b.value == state['a'].value @@ -340,7 +324,7 @@ def __init__(self): assert isinstance(m2.a, nnx.Param) assert isinstance(m2.b, nnx.Param) - assert issubclass(state['a'].type, nnx.Param) + assert isinstance(state['a'], nnx.Param) assert m2.a is not state['a'] assert m2.b is not state['a'] assert m2.a.value == state['a'].value @@ -933,7 +917,7 @@ def test_split_variable(self): graphdef, state = nnx.split(v) self.assertIsInstance(graphdef.nodes[0], nnx.graph.VariableDef) - self.assertIsInstance(state, nnx.VariableState) + self.assertIsInstance(state, nnx.Variable) v2 = nnx.merge(graphdef, state) self.assertIsInstance(v2, nnx.Param) @@ -945,7 +929,7 @@ def test_split_filter_variable(self): ) self.assertIsInstance(graphdef.nodes[0], nnx.graph.VariableDef) - self.assertIsInstance(params, nnx.VariableState) + self.assertIsInstance(params, nnx.Variable) self.assertIsInstance(batch_stats, nnx.State) self.assertEmpty(batch_stats) self.assertIsInstance(rest, nnx.State) @@ -959,7 +943,7 @@ def test_split_update_variable(self): graphdef, state = nnx.split(v) self.assertIsInstance(graphdef.nodes[0], nnx.graph.VariableDef) - self.assertIsInstance(state, nnx.VariableState) + self.assertIsInstance(state, nnx.Variable) state.value = 2 nnx.update(v, state) @@ -973,7 +957,7 @@ def test_split_update_filter_variable(self): ) self.assertIsInstance(graphdef.nodes[0], nnx.graph.VariableDef) - self.assertIsInstance(params, nnx.VariableState) + self.assertIsInstance(params, nnx.Variable) self.assertIsInstance(batch_stats, nnx.State) self.assertEmpty(batch_stats) self.assertIsInstance(rest, nnx.State) @@ -1098,7 +1082,6 @@ def test_threading(self): x = SimpleModule() class MyThread(Thread): - def run(self) -> None: nnx.graph.split(x) diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index 7722bf91c..a8001e07a 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -281,7 +281,7 @@ def __call__(self, x): intermediates = nnx.pop(m, nnx.Intermediate) - assert issubclass(intermediates['y'].type, nnx.Intermediate) + assert isinstance(intermediates['y'], nnx.Intermediate) assert intermediates['y'].value == (3, 11) assert not hasattr(m, 'y') @@ -622,13 +622,13 @@ def __call__(self, x): expected_total = sum(int(np.prod(x.value.shape)) for x in leaves) expected_total_params = sum( - int(np.prod(x.value.shape)) for x in leaves if x.type is nnx.Param + int(np.prod(x.value.shape)) for x in leaves if isinstance(x, nnx.Param) ) expected_total_batch_stats = sum( - int(np.prod(x.value.shape)) for x in leaves if x.type is nnx.BatchStat + int(np.prod(x.value.shape)) for x in leaves if isinstance(x, nnx.BatchStat) ) expected_total_rng_states = sum( - int(np.prod(x.value.shape)) for x in leaves if x.type is nnx.RngState + int(np.prod(x.value.shape)) for x in leaves if isinstance(x, nnx.RngState) ) foo_repr = repr(obj).replace(',', '').splitlines() @@ -664,13 +664,13 @@ class Foo(nnx.Module): assert len(state) == 4 assert state['b'].value == 2 - assert state['b'].type == nnx.Variable + assert isinstance(state['b'], nnx.Variable) assert state['c'].value == 3 - assert state['c'].type == nnx.Param + assert isinstance(state['c'], nnx.Param) assert state['d'].value == 4 - assert state['d'].type == nnx.Variable + assert isinstance(state['d'], nnx.Variable) assert state['e'].value == 5 - assert state['e'].type == nnx.BatchStat + assert isinstance(state['e'], nnx.BatchStat) def test_post_init(self): @@ -708,7 +708,7 @@ def __call__(self, x, *, rngs: nnx.Rngs): graphdef, states = nnx.split(foo) assert isinstance(states, nnx.State) - assert issubclass(states['w'].type, nnx.Param) + assert isinstance(states['w'], nnx.Param) y, _updates = graphdef.apply(states)(x=2.0, rngs=nnx.Rngs(e=1)) @@ -732,8 +732,8 @@ def __call__(self, x, *, rngs: nnx.Rngs): assert isinstance(graphdef.nodes[0], nnx.graph.NodeDef | nnx.graph.NodeRef) assert isinstance(state, nnx.State) - assert issubclass(state['w'].type, nnx.Param) - assert issubclass(state['c'].type, nnx.Variable) + assert isinstance(state['w'], nnx.Param) + assert isinstance(state['c'], nnx.Variable) y, (graphdef, state) = graphdef.apply(state)(x=2.0, rngs=nnx.Rngs(e=1)) diff --git a/tests/nnx/nn/attention_test.py b/tests/nnx/nn/attention_test.py index 26709e335..2b526fa08 100644 --- a/tests/nnx/nn/attention_test.py +++ b/tests/nnx/nn/attention_test.py @@ -125,8 +125,8 @@ def test_keep_rngs(self, keep_rngs): assert module.rngs is None if keep_rngs: _, _, nondiff = nnx.split(module, nnx.Param, ...) - assert nondiff['rngs']['count'].type is nnx.RngCount - assert nondiff['rngs']['key'].type is nnx.RngKey + assert isinstance(nondiff['rngs']['count'], nnx.RngCount) + assert isinstance(nondiff['rngs']['key'], nnx.RngKey) else: nnx.split(module, nnx.Param) diff --git a/tests/nnx/optimizer_test.py b/tests/nnx/optimizer_test.py index 05aac74c7..e49110b76 100644 --- a/tests/nnx/optimizer_test.py +++ b/tests/nnx/optimizer_test.py @@ -241,7 +241,7 @@ def test_wrt_update(self, variable): rngs=nnx.Rngs(1), ) state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable) - prev_variables, prev_other_variables = nnx.state(model, variable, ...) + prev_variables, prev_other_variables = nnx.clone(nnx.state(model, variable, ...)) x = jnp.ones((1, 4)) y = jnp.ones((1, 10)) @@ -288,7 +288,7 @@ def test_wrt_update_linesearch(self, variable): rngs=nnx.Rngs(1), ) state = nnx.Optimizer(model, optax.lbfgs(), wrt=variable) - prev_variables, prev_other_variables = nnx.state(model, variable, ...) + prev_variables, prev_other_variables = nnx.clone(nnx.state(model, variable, ...)) x = jnp.ones((1, 4)) y = jnp.ones((1, 10)) diff --git a/tests/nnx/state_test.py b/tests/nnx/state_test.py index e1f78d1d5..8e8441373 100644 --- a/tests/nnx/state_test.py +++ b/tests/nnx/state_test.py @@ -21,19 +21,19 @@ class StateTest(absltest.TestCase): def test_create_state(self): - state = nnx.State({'a': nnx.Param.state(1), 'b': {'c': nnx.Param.state(2)}}) + state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) assert state['a'].value == 1 assert state['b']['c'].value == 2 def test_get_attr(self): - state = nnx.State({'a': nnx.Param.state(1), 'b': {'c': nnx.Param.state(2)}}) + state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) assert state.a.value == 1 assert state.b.c.value == 2 def test_set_attr(self): - state = nnx.State({'a': nnx.Param.state(1), 'b': {'c': nnx.Param.state(2)}}) + state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) state.a.value = 3 state.b.c.value = 4 @@ -42,24 +42,24 @@ def test_set_attr(self): assert state['b']['c'].value == 4 def test_set_attr_variables(self): - state = nnx.State({'a': nnx.Param.state(1), 'b': {'c': nnx.Param.state(2)}}) + state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) state.a.value = 3 state.b.c.value = 4 - assert issubclass(state.a.type, nnx.Param) + assert isinstance(state.a, nnx.Param) assert state.a.value == 3 - assert issubclass(state.b.c.type, nnx.Param) + assert isinstance(state.b.c, nnx.Param) assert state.b.c.value == 4 def test_add_nested_attr(self): - state = nnx.State({'a': nnx.Param.state(1), 'b': {'c': nnx.Param.state(2)}}) - state.b.d = nnx.Param.state(5) + state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) + state.b.d = nnx.Param(5) assert state['b']['d'].value == 5 def test_delete_nested_attr(self): - state = nnx.State({'a': nnx.Param.state(1), 'b': {'c': nnx.Param.state(2)}}) + state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) del state['b']['c'] assert 'c' not in state['b'] @@ -86,9 +86,9 @@ def test_pure_dict(self): assert isinstance(pure_dict['bias'], jax.Array) nnx.replace_by_pure_dict(state, jax.tree.map(jnp.zeros_like, pure_dict)) assert isinstance(state, nnx.State) - assert isinstance(state['kernel'], nnx.VariableState) + assert isinstance(state['kernel'], nnx.Variable) assert jnp.array_equal(state['kernel'].value, jnp.zeros((4, 5))) - assert state['kernel'].type == nnx.Param + assert type(state['kernel']) == nnx.Param nnx.update(module, state) assert jnp.array_equal(module(jnp.ones((3, 4))), jnp.zeros((3, 5))) diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index fb40c14b0..d32ed9cea 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -406,16 +406,16 @@ def test_cache_args(self): @nnx.jit def f(cached_m: nnx.Linear, m: nnx.Linear): self.assertIsNot(cached_m, m) - self.assertIs(cached_m.kernel, m.kernel) - self.assertIs(cached_m.bias, m.bias) + self.assertIsNot(cached_m.kernel, m.kernel) + self.assertIsNot(cached_m.bias, m.bias) return cached_m cached_f = nnx.cached_partial(f, m) cached_m = cached_f(m) self.assertIsNot(m, cached_m) - self.assertIs(m.kernel, cached_m.kernel) - self.assertIs(m.bias, cached_m.bias) + self.assertIsNot(m.kernel, cached_m.kernel) + self.assertIsNot(m.bias, cached_m.bias) # test that cached m is reused cached_m2 = cached_f(m) @@ -451,6 +451,11 @@ def f(m: nnx.Linear, x): y = compiled(m, x) self.assertEqual(m.count.value, 2) +class TestEvalShape(absltest.TestCase): + def test_eval_shape(self): + abs_model = nnx.eval_shape(lambda: nnx.Linear(1, 2, rngs=nnx.Rngs(0))) + self.assertIsInstance(abs_model, nnx.Linear) + self.assertIsInstance(abs_model.kernel.value, jax.ShapeDtypeStruct) class TestShardMap(absltest.TestCase): def test_basic_shardmap(self): diff --git a/tests/nnx/variable_test.py b/tests/nnx/variable_test.py index 2881475bc..270a7d086 100644 --- a/tests/nnx/variable_test.py +++ b/tests/nnx/variable_test.py @@ -22,9 +22,9 @@ A = tp.TypeVar('A') -class TestVariableState(absltest.TestCase): +class TestVariable(absltest.TestCase): def test_pytree(self): - r1 = nnx.VariableState(nnx.Param, 1) + r1 = nnx.Param(1) self.assertEqual(r1.value, 1) r2 = jax.tree.map(lambda x: x + 1, r1)