@@ -39,7 +39,7 @@ def test_encode(self, vocab_size, embed_dim, inputs, expected):
3939 embed_dim = embed_dim ,
4040 rngs = nnx .Rngs (params = 0 ),
4141 )
42- embedder .input_embedding . value = jnp .ones ((vocab_size , embed_dim ))
42+ embedder .input_embedding [...] = jnp .ones ((vocab_size , embed_dim ))
4343 output = embedder .encode (inputs )
4444 np .testing .assert_array_equal (output , jnp .array (expected ))
4545
@@ -57,7 +57,7 @@ def test_decode(self, vocab_size, embed_dim, inputs, expected):
5757 embed_dim = embed_dim ,
5858 rngs = nnx .Rngs (params = 0 ),
5959 )
60- embedder .input_embedding . value = jnp .ones ((vocab_size , embed_dim ))
60+ embedder .input_embedding [...] = jnp .ones ((vocab_size , embed_dim ))
6161 output = embedder .decode (jnp .array (inputs ))
6262 np .testing .assert_array_equal (output , jnp .array (expected ))
6363
@@ -228,9 +228,9 @@ def test_ffw(
228228 hidden_dim = hidden_dim ,
229229 rngs = nnx .Rngs (params = 0 ),
230230 )
231- ffw .gate_proj .kernel . value = jnp .ones ((features , hidden_dim ))
232- ffw .up_proj .kernel . value = jnp .ones ((features , hidden_dim ))
233- ffw .down_proj .kernel . value = jnp .ones ((hidden_dim , features ))
231+ ffw .gate_proj .kernel [...] = jnp .ones ((features , hidden_dim ))
232+ ffw .up_proj .kernel [...] = jnp .ones ((features , hidden_dim ))
233+ ffw .down_proj .kernel [...] = jnp .ones ((hidden_dim , features ))
234234
235235 with jax .default_matmul_precision ('float32' ):
236236 outputs = ffw (inputs )
0 commit comments