Skip to content

Commit c1c0c1c

Browse files
author
jax authors
committed
Merge pull request #19634 from jakevdp:key-reuse-scan
PiperOrigin-RevId: 604418753
2 parents 69a9f7f + f4f8293 commit c1c0c1c

File tree

3 files changed

+25
-10
lines changed

3 files changed

+25
-10
lines changed

jax/experimental/key_reuse/_forwarding.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,16 +260,20 @@ def _scan_key_type_signature(eqn, args_consumed):
260260

261261
# scan body should not consume key in constants
262262
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
263-
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed:\n"
263+
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, "
264+
"because key constants are repeatedly consumed:\n"
264265
f" {signature=}\n"
265266
f" {eqn=}\n"
266267
f" {jaxpr=}")
267268

268269
# scan carry should only consume keys that are sourced on output.
269-
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks if 0 <= s.idx - num_consts < num_carry}
270-
carry_sources = {s.idx: s.mask for s in signature.sources if 0 <= s.idx < num_carry}
271-
if carry_sinks.keys() != carry_sources.keys(): # TODO(jakevdp): check that masks match
272-
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed:\n"
270+
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks
271+
if 0 <= s.idx - num_consts < num_carry and np.any(s.mask)}
272+
carry_sources = {s.idx: s.mask for s in signature.sources
273+
if 0 <= s.idx < num_carry and np.any(s.mask)}
274+
if not set(carry_sinks).issubset(set(carry_sources)): # TODO(jakevdp): check that masks match
275+
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, "
276+
"because consumed inputs don't match sourced outputs:\n"
273277
f" {signature=}\n"
274278
f" {eqn=}\n"
275279
f" {jaxpr=}")

jax/experimental/key_reuse/_simple.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,16 +231,20 @@ def _scan_key_type_signature(eqn, args_consumed):
231231

232232
# scan body should not consume key in constants
233233
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
234-
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed:\n"
234+
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, "
235+
"because key constants are repeatedly consumed:\n"
235236
f" {signature=}\n"
236237
f" {eqn=}\n"
237238
f" {jaxpr=}")
238239

239240
# scan carry should only consume keys that are sourced on output.
240-
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks if 0 <= s.idx - num_consts < num_carry}
241-
carry_sources = {s.idx: s.mask for s in signature.sources if 0 <= s.idx < num_carry}
242-
if carry_sinks.keys() != carry_sources.keys(): # TODO(jakevdp): check that masks match
243-
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed:\n"
241+
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks
242+
if 0 <= s.idx - num_consts < num_carry and np.any(s.mask)}
243+
carry_sources = {s.idx: s.mask for s in signature.sources
244+
if 0 <= s.idx < num_carry and np.any(s.mask)}
245+
if not set(carry_sinks).issubset(set(carry_sources)):
246+
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, "
247+
"because consumed inputs don't match sourced outputs:\n"
244248
f" {signature=}\n"
245249
f" {eqn=}\n"
246250
f" {jaxpr=}")

tests/key_reuse_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,13 @@ def f_scan_over_keys(key):
710710
return jax.lax.map(jax.random.bits, keys)
711711
self.check_key_reuse(f_scan_over_keys, jax.random.key(0))
712712

713+
def test_scan_consume_one(self):
714+
def f_scan_over_keys(*keys):
715+
def body_func(keys, x):
716+
return tuple(jax.random.split(keys[0])), x
717+
return jax.lax.scan(body_func, keys, xs=jnp.arange(10))
718+
self.check_key_reuse(f_scan_over_keys, jax.random.key(0), jax.random.key(1))
719+
713720
def test_vmap(self):
714721
@jax.vmap
715722
def f_good(seed):

0 commit comments

Comments
 (0)