diff --git a/examples/gemma/modules_test.py b/examples/gemma/modules_test.py index 6f3140a58..a5d0ba8a1 100644 --- a/examples/gemma/modules_test.py +++ b/examples/gemma/modules_test.py @@ -39,7 +39,7 @@ def test_encode(self, vocab_size, embed_dim, inputs, expected): embed_dim=embed_dim, rngs=nnx.Rngs(params=0), ) - embedder.input_embedding.value = jnp.ones((vocab_size, embed_dim)) + embedder.input_embedding[...] = jnp.ones((vocab_size, embed_dim)) output = embedder.encode(inputs) np.testing.assert_array_equal(output, jnp.array(expected)) @@ -57,7 +57,7 @@ def test_decode(self, vocab_size, embed_dim, inputs, expected): embed_dim=embed_dim, rngs=nnx.Rngs(params=0), ) - embedder.input_embedding.value = jnp.ones((vocab_size, embed_dim)) + embedder.input_embedding[...] = jnp.ones((vocab_size, embed_dim)) output = embedder.decode(jnp.array(inputs)) np.testing.assert_array_equal(output, jnp.array(expected)) @@ -228,9 +228,9 @@ def test_ffw( hidden_dim=hidden_dim, rngs=nnx.Rngs(params=0), ) - ffw.gate_proj.kernel.value = jnp.ones((features, hidden_dim)) - ffw.up_proj.kernel.value = jnp.ones((features, hidden_dim)) - ffw.down_proj.kernel.value = jnp.ones((hidden_dim, features)) + ffw.gate_proj.kernel[...] = jnp.ones((features, hidden_dim)) + ffw.up_proj.kernel[...] = jnp.ones((features, hidden_dim)) + ffw.down_proj.kernel[...] = jnp.ones((hidden_dim, features)) with jax.default_matmul_precision('float32'): outputs = ffw(inputs)