Wait...
I am really confused too. It seems now that …
In addition, I have another question. The release notes say that for some simple functions we no longer need to use `merge` and `split` in the code. How efficient is this? Is there still the relatively large overhead mentioned in the earlier performance considerations? (I noticed that when a function is decorated with `jax.jit`, it does seem noticeably faster than with `nnx.jit`.)
The split and merge paradigm is still there and won't be deprecated. Here is an example of usage: flax/examples/nnx_toy_examples/mutable_array_basic.py, lines 62 to 71 (commit 2bf5748). In the special case where the model contains only parameters, one can simplify the code and remove the split/merge calls.
I'm sure the incoming docs will cover everything, but I'm curious about the …
v0.11.0 - Pytrees, MutableArrays, and more!
This version of Flax introduces some changes to improve interop with native JAX and adds support for the new jax.experimental.MutableArray. More on this soon! However, some breaking changes to align with the JAX way of doing things were necessary. Most code should remain intact; however, the following changes deviate from the current behavior:

- Rngs in standard layers: standard layers no longer hold a shared reference to the `rngs` object given in the constructor; instead, they now keep a `fork`-ed copy of the `Rngs` or `RngStream` objects. This impacts Using Rngs in NNX Transforms and Loading Checkpoints with RNGs.
- Optimizer updates: the `Optimizer` no longer holds a reference to the `model`, to avoid reference sharing; instead, the `model` must be provided as the first argument to `update`.
- Modules as Pytrees: Modules are now pytrees, which avoids unnecessary use of `split` and `merge` when interacting trivially with raw JAX transforms (state must still be manually propagated if not using MutableArrays, and referential transparency is still an issue). This affects code that operates on Pytrees containing NNX Objects with `jax.tree.*` APIs.

Check out the full NNX 0.10 to NNX 0.11 migration guide.
In the near future we'll share more information about new ways of using NNX with JAX transforms directly by leveraging the new Pytree and MutableArray support. Stay tuned!
What's Changed
- Fix failing CI jobs: trailing whitespace, deprecated `.type` usage by @vfdev-5 in #4823
- refactor: move usages of `.value` to `[...]` in modules_test.py by @lukeyeh in #4815
- Migrate `transforms_test.py` from `.value` to `[...]` by @lukeyeh in #4841

New Contributors
- @lukeyeh made their first contribution in #4815

Full Changelog: v0.10.7...v0.11.0
This discussion was created from the release v0.11.0.