Skip to content

Conversation

@cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented May 8, 2025

What does this PR do?

  • Adds nnx.mutable_array and nnx.MutableArray, re-exporting them from jax.experimental.
  • Uses [...] notation in some places instead of .value.
  • Adds __data__ to most / all public Object-derived types.
  • __data__'s type is now tuple[str, ...] | Literal['all', 'auto']

@cgarciae cgarciae force-pushed the mutable-array-p2 branch 5 times, most recently from 75f2a1e to 6ef64e3 Compare May 9, 2025 20:41
@cgarciae cgarciae force-pushed the mutable-array-p2 branch 14 times, most recently from 3c3ed21 to 82564df Compare May 13, 2025 18:33
@copybara-service copybara-service bot force-pushed the test_755993917 branch 3 times, most recently from 4dc58b7 to 90b3744 Compare May 14, 2025 00:18
@cgarciae cgarciae force-pushed the mutable-array-p2 branch from 82564df to 3b81664 Compare May 14, 2025 00:30
@cgarciae cgarciae marked this pull request as ready for review May 14, 2025 00:37
Base automatically changed from test_755993917 to main May 14, 2025 00:48
@cgarciae cgarciae force-pushed the mutable-array-p2 branch 2 times, most recently from bf86437 to e264671 Compare May 14, 2025 23:17
@cgarciae cgarciae force-pushed the mutable-array-p2 branch 2 times, most recently from e5ae64a to e571f5a Compare May 15, 2025 22:37
@cgarciae cgarciae force-pushed the mutable-array-p2 branch 2 times, most recently from 12da3fc to af4a0ad Compare May 15, 2025 23:09
@cgarciae cgarciae force-pushed the mutable-array-p2 branch from af4a0ad to f50b4d9 Compare May 15, 2025 23:22
@copybara-service copybara-service bot merged commit a5eebe5 into main May 17, 2025
18 of 19 checks passed
@copybara-service copybara-service bot deleted the mutable-array-p2 branch May 17, 2025 00:25
@ravwojdyla
Copy link

👋 @cgarciae when I tried the 1st toy example (Functional API) using latest release 0.10.6 (directed from https://flax.readthedocs.io/en/latest/examples/core_examples.html#toy-examples). It failed with:

TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html

On line:

That line was changed in this PR. I noticed the same operation in the 2nd and 3rd example are ATM unchanged:

I can see that you have the part 3 PR in #4755, which is not updating the 2nd and 3rd examples.

Some questions for you:

  • the failure on the latest release is as expected?
  • should the examples 2 and 3 also be updated?

Btw, pending your response/direction, I'm happy to help 🙏

@Tomas542
Copy link

@ravwojdyla, hi. Did you installed flax from PyPI or build it from source? I think, Mutable Arrays changes not on the PyPI builds for now

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants