File tree Expand file tree Collapse file tree 3 files changed +25
-10
lines changed
jax/experimental/key_reuse Expand file tree Collapse file tree 3 files changed +25
-10
lines changed Original file line number Diff line number Diff line change @@ -260,16 +260,20 @@ def _scan_key_type_signature(eqn, args_consumed):
260260
261261 # scan body should not consume key in constants
262262 if any (np .any (s .mask ) for s in signature .sinks if s .idx < num_consts ):
263- raise KeyReuseError ("scan body function leads to key reuse when repeatedly executed:\n "
263+ raise KeyReuseError ("scan body function leads to key reuse when repeatedly executed, "
264+ "because key constants are repeatedly consumed:\n "
264265 f" { signature = } \n "
265266 f" { eqn = } \n "
266267 f" { jaxpr = } " )
267268
268269 # scan carry should only consume keys that are sourced on output.
269- carry_sinks = {s .idx - num_consts : s .mask for s in signature .sinks if 0 <= s .idx - num_consts < num_carry }
270- carry_sources = {s .idx : s .mask for s in signature .sources if 0 <= s .idx < num_carry }
271- if carry_sinks .keys () != carry_sources .keys (): # TODO(jakevdp): check that masks match
272- raise KeyReuseError (f"scan body function leads to key reuse when repeatedly executed:\n "
270+ carry_sinks = {s .idx - num_consts : s .mask for s in signature .sinks
271+ if 0 <= s .idx - num_consts < num_carry and np .any (s .mask )}
272+ carry_sources = {s .idx : s .mask for s in signature .sources
273+ if 0 <= s .idx < num_carry and np .any (s .mask )}
274+ if not set (carry_sinks ).issubset (set (carry_sources )): # TODO(jakevdp): check that masks match
275+ raise KeyReuseError ("scan body function leads to key reuse when repeatedly executed, "
276+ "because consumed inputs don't match sourced outputs:\n "
273277 f" { signature = } \n "
274278 f" { eqn = } \n "
275279 f" { jaxpr = } " )
Original file line number Diff line number Diff line change @@ -231,16 +231,20 @@ def _scan_key_type_signature(eqn, args_consumed):
231231
232232 # scan body should not consume key in constants
233233 if any (np .any (s .mask ) for s in signature .sinks if s .idx < num_consts ):
234- raise KeyReuseError (f"scan body function leads to key reuse when repeatedly executed:\n "
234+ raise KeyReuseError ("scan body function leads to key reuse when repeatedly executed, "
235+ "because key constants are repeatedly consumed:\n "
235236 f" { signature = } \n "
236237 f" { eqn = } \n "
237238 f" { jaxpr = } " )
238239
239240 # scan carry should only consume keys that are sourced on output.
240- carry_sinks = {s .idx - num_consts : s .mask for s in signature .sinks if 0 <= s .idx - num_consts < num_carry }
241- carry_sources = {s .idx : s .mask for s in signature .sources if 0 <= s .idx < num_carry }
242- if carry_sinks .keys () != carry_sources .keys (): # TODO(jakevdp): check that masks match
243- raise KeyReuseError (f"scan body function leads to key reuse when repeatedly executed:\n "
241+ carry_sinks = {s .idx - num_consts : s .mask for s in signature .sinks
242+ if 0 <= s .idx - num_consts < num_carry and np .any (s .mask )}
243+ carry_sources = {s .idx : s .mask for s in signature .sources
244+ if 0 <= s .idx < num_carry and np .any (s .mask )}
245+ if not set (carry_sinks ).issubset (set (carry_sources )):
246+ raise KeyReuseError ("scan body function leads to key reuse when repeatedly executed, "
247+ "because consumed inputs don't match sourced outputs:\n "
244248 f" { signature = } \n "
245249 f" { eqn = } \n "
246250 f" { jaxpr = } " )
Original file line number Diff line number Diff line change @@ -710,6 +710,13 @@ def f_scan_over_keys(key):
710710 return jax .lax .map (jax .random .bits , keys )
711711 self .check_key_reuse (f_scan_over_keys , jax .random .key (0 ))
712712
713+ def test_scan_consume_one (self ):
714+ def f_scan_over_keys (* keys ):
715+ def body_func (keys , x ):
716+ return tuple (jax .random .split (keys [0 ])), x
717+ return jax .lax .scan (body_func , keys , xs = jnp .arange (10 ))
718+ self .check_key_reuse (f_scan_over_keys , jax .random .key (0 ), jax .random .key (1 ))
719+
713720 def test_vmap (self ):
714721 @jax .vmap
715722 def f_good (seed ):
You can’t perform that action at this time.
0 commit comments