Skip to content

Commit 1ccd159

Browse files
levskayaFlax Authors
authored andcommitted
Allow using LazyRngs for flax init/apply.
This is occasionally useful for refactoring while preserving rng derivations. PiperOrigin-RevId: 783598096
1 parent e147958 commit 1ccd159

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

flax/core/scope.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,10 @@ def _is_valid_variables(variables: VariableDict) -> bool:
11841184

11851185
def _is_valid_rng(rng: Array):
11861186
"""Checks whether rng is a valid JAX PRNGKey, also handling custom prngs."""
1187+
# Allow for user-provided LazyRng - useful for compatibility when refactoring.
1188+
if isinstance(rng, LazyRng):
1189+
return True
1190+
11871191
# This check is valid for either new-style or old-style PRNG keys
11881192
if not isinstance(rng, (np.ndarray, jnp.ndarray)):
11891193
return False

tests/core/core_scope_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ def test_rng_check_w_old_and_new_keys(self):
197197
self.assertTrue(scope._is_valid_rng(raw_key))
198198
self.assertFalse(scope._is_valid_rng(random.split(raw_key)))
199199

200+
def test_rng_check_w_lazy_rng(self):
201+
key = random.key(0)
202+
self.assertTrue(scope._is_valid_rng(scope.LazyRng.create(key, 1)))
203+
200204
def test_jax_leak_detector(self):
201205
with jax.check_tracer_leaks(True):
202206

0 commit comments

Comments
 (0)