Skip to content

Commit db0e302

Browse files
author
Flax Authors
committed
Merge pull request #4815 from lukeyeh:push-tpqklpwqzxkq
PiperOrigin-RevId: 785884101
2 parents 8447c03 + 9acdcfd commit db0e302

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

examples/gemma/modules_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_encode(self, vocab_size, embed_dim, inputs, expected):
3939
embed_dim=embed_dim,
4040
rngs=nnx.Rngs(params=0),
4141
)
42-
embedder.input_embedding.value = jnp.ones((vocab_size, embed_dim))
42+
embedder.input_embedding[...] = jnp.ones((vocab_size, embed_dim))
4343
output = embedder.encode(inputs)
4444
np.testing.assert_array_equal(output, jnp.array(expected))
4545

@@ -57,7 +57,7 @@ def test_decode(self, vocab_size, embed_dim, inputs, expected):
5757
embed_dim=embed_dim,
5858
rngs=nnx.Rngs(params=0),
5959
)
60-
embedder.input_embedding.value = jnp.ones((vocab_size, embed_dim))
60+
embedder.input_embedding[...] = jnp.ones((vocab_size, embed_dim))
6161
output = embedder.decode(jnp.array(inputs))
6262
np.testing.assert_array_equal(output, jnp.array(expected))
6363

@@ -228,9 +228,9 @@ def test_ffw(
228228
hidden_dim=hidden_dim,
229229
rngs=nnx.Rngs(params=0),
230230
)
231-
ffw.gate_proj.kernel.value = jnp.ones((features, hidden_dim))
232-
ffw.up_proj.kernel.value = jnp.ones((features, hidden_dim))
233-
ffw.down_proj.kernel.value = jnp.ones((hidden_dim, features))
231+
ffw.gate_proj.kernel[...] = jnp.ones((features, hidden_dim))
232+
ffw.up_proj.kernel[...] = jnp.ones((features, hidden_dim))
233+
ffw.down_proj.kernel[...] = jnp.ones((hidden_dim, features))
234234

235235
with jax.default_matmul_precision('float32'):
236236
outputs = ffw(inputs)

0 commit comments

Comments
 (0)