Skip to content

Commit 69a9f7f

Browse files
author
jax authors
committed
Merge pull request #19629 from jakevdp:key-reuse-pjit
PiperOrigin-RevId: 604404276
2 parents 206398a + f453442 commit 69a9f7f

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

jax/experimental/key_reuse/_forwarding.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,11 @@ def _pjit_key_type_signature(eqn, args_consumed):
211211
jaxpr = eqn.params['jaxpr']
212212
forwarded_inputs = {i: eqn.invars.index(var) for i, var in enumerate(eqn.invars)
213213
if var in eqn.invars[:i]}
214-
return get_jaxpr_type_signature(jaxpr.jaxpr, consumed_inputs=args_consumed,
215-
forwarded_inputs=forwarded_inputs)
214+
sig = get_jaxpr_type_signature(jaxpr.jaxpr)
215+
if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks):
216+
# Double consumption detected: re-trace with context for better errors.
217+
get_jaxpr_type_signature(jaxpr.jaxpr, args_consumed, forwarded_inputs)
218+
return sig
216219

217220
key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature
218221

jax/experimental/key_reuse/_simple.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,11 @@ def _pjit_key_type_signature(eqn, args_consumed):
183183
non_literal_invars = [v for v in eqn.invars if not isinstance(v, core.Literal)]
184184
if len(set(non_literal_invars)) != len(non_literal_invars):
185185
raise ValueError(f"pjit with duplicate inputs: {eqn.invars=}")
186-
return get_jaxpr_type_signature(jaxpr.jaxpr, consumed_inputs=args_consumed)
186+
sig = get_jaxpr_type_signature(jaxpr.jaxpr)
187+
if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks):
188+
# Double consumption detected: re-trace with context for better errors.
189+
get_jaxpr_type_signature(jaxpr.jaxpr, args_consumed)
190+
return sig
187191

188192
key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature
189193

tests/key_reuse_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
573573
random_bits_error = "In random_bits, key values .+ are already consumed.*"
574574
random_split_error = "In random_split, key values .+ are already consumed.*"
575575
generic_error = ".*key values .+ are already consumed.*"
576+
pjit_error = "In pjit, key values a are already consumed."
576577

577578
def check_key_reuse(self, f, *args):
578579
if self.use_forwarding:
@@ -782,6 +783,18 @@ def body_fun(i):
782783
with self.assertRaisesRegex(KeyReuseError, "while_loop cond function leads to key reuse"):
783784
self.check_key_reuse(f, 0)
784785

786+
def test_pjit_consumed_input(self):
787+
@jax.jit
788+
def g(key, x): # doesn't consume key
789+
return x
790+
791+
def f(seed):
792+
key = jax.random.key(seed)
793+
x = jax.random.bits(key)
794+
return g(key, x)
795+
796+
self.check_key_reuse(f, 0)
797+
785798

786799
class KeyReuseIntegrationTestSimple(KeyReuseIntegrationTest):
787800
use_forwarding = False

0 commit comments

Comments
 (0)