File tree Expand file tree Collapse file tree 3 files changed +23
-3
lines changed
jax/experimental/key_reuse Expand file tree Collapse file tree 3 files changed +23
-3
lines changed Original file line number Diff line number Diff line change @@ -211,8 +211,11 @@ def _pjit_key_type_signature(eqn, args_consumed):
211211 jaxpr = eqn .params ['jaxpr' ]
212212 forwarded_inputs = {i : eqn .invars .index (var ) for i , var in enumerate (eqn .invars )
213213 if var in eqn .invars [:i ]}
214- return get_jaxpr_type_signature (jaxpr .jaxpr , consumed_inputs = args_consumed ,
215- forwarded_inputs = forwarded_inputs )
214+ sig = get_jaxpr_type_signature (jaxpr .jaxpr )
215+ if args_consumed and any (np .any (args_consumed [s .idx ] & s .mask ) for s in sig .sinks ):
216+ # Double consumption detected: re-trace with context for better errors.
217+ get_jaxpr_type_signature (jaxpr .jaxpr , args_consumed , forwarded_inputs )
218+ return sig
216219
217220key_reuse_signatures_dynamic [pjit .pjit_p ] = _pjit_key_type_signature
218221
Original file line number Diff line number Diff line change @@ -183,7 +183,11 @@ def _pjit_key_type_signature(eqn, args_consumed):
183183 non_literal_invars = [v for v in eqn .invars if not isinstance (v , core .Literal )]
184184 if len (set (non_literal_invars )) != len (non_literal_invars ):
185185 raise ValueError (f"pjit with duplicate inputs: { eqn .invars = } " )
186- return get_jaxpr_type_signature (jaxpr .jaxpr , consumed_inputs = args_consumed )
186+ sig = get_jaxpr_type_signature (jaxpr .jaxpr )
187+ if args_consumed and any (np .any (args_consumed [s .idx ] & s .mask ) for s in sig .sinks ):
188+ # Double consumption detected: re-trace with context for better errors.
189+ get_jaxpr_type_signature (jaxpr .jaxpr , args_consumed )
190+ return sig
187191
188192key_reuse_signatures_dynamic [pjit .pjit_p ] = _pjit_key_type_signature
189193
Original file line number Diff line number Diff line change @@ -573,6 +573,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
573573 random_bits_error = "In random_bits, key values .+ are already consumed.*"
574574 random_split_error = "In random_split, key values .+ are already consumed.*"
575575 generic_error = ".*key values .+ are already consumed.*"
576+ pjit_error = "In pjit, key values a are already consumed."
576577
577578 def check_key_reuse (self , f , * args ):
578579 if self .use_forwarding :
@@ -782,6 +783,18 @@ def body_fun(i):
782783 with self .assertRaisesRegex (KeyReuseError , "while_loop cond function leads to key reuse" ):
783784 self .check_key_reuse (f , 0 )
784785
786+ def test_pjit_consumed_input (self ):
787+ @jax .jit
788+ def g (key , x ): # doesn't consume key
789+ return x
790+
791+ def f (seed ):
792+ key = jax .random .key (seed )
793+ x = jax .random .bits (key )
794+ return g (key , x )
795+
796+ self .check_key_reuse (f , 0 )
797+
785798
786799class KeyReuseIntegrationTestSimple (KeyReuseIntegrationTest ):
787800 use_forwarding = False
You can’t perform that action at this time.
0 commit comments