File tree Expand file tree Collapse file tree 2 files changed +8
-0
lines changed
Expand file tree Collapse file tree 2 files changed +8
-0
lines changed Original file line number Diff line number Diff line change @@ -1184,6 +1184,10 @@ def _is_valid_variables(variables: VariableDict) -> bool:
11841184
11851185def _is_valid_rng (rng : Array ):
11861186 """Checks whether rng is a valid JAX PRNGKey, also handling custom prngs."""
1187+ # Allow for user-provided LazyRng - useful for compatibility when refactoring.
1188+ if isinstance (rng , LazyRng ):
1189+ return True
1190+
11871191 # This check is valid for either new-style or old-style PRNG keys
11881192 if not isinstance (rng , (np .ndarray , jnp .ndarray )):
11891193 return False
Original file line number Diff line number Diff line change @@ -197,6 +197,10 @@ def test_rng_check_w_old_and_new_keys(self):
197197 self .assertTrue (scope ._is_valid_rng (raw_key ))
198198 self .assertFalse (scope ._is_valid_rng (random .split (raw_key )))
199199
200+ def test_rng_check_w_lazy_rng (self ):
201+ key = random .key (0 )
202+ self .assertTrue (scope ._is_valid_rng (scope .LazyRng .create (key , 1 )))
203+
200204 def test_jax_leak_detector (self ):
201205 with jax .check_tracer_leaks (True ):
202206
You can’t perform that action at this time.
0 commit comments