Skip to content

Conversation

@cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Mar 11, 2025

What does this PR do?

Adds support for Array leaves in graph nodes.

Example:

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = jax.random.normal(rngs.params(), (din, dout))
    self.b = jnp.zeros((dout,))

  def __call__(self, x):
    return x @ self.w + self.b

@cgarciae cgarciae force-pushed the nnx-array-leaves branch 2 times, most recently from 18d3750 to b524cca Compare March 14, 2025 05:01
@cgarciae cgarciae marked this pull request as ready for review March 15, 2025 06:32
@copybara-service copybara-service bot merged commit 097fffc into main Mar 24, 2025
18 checks passed
@copybara-service copybara-service bot deleted the nnx-array-leaves branch March 24, 2025 23:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants