We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e224c3d commit 82611ebCopy full SHA for 82611eb
jax/_src/api.py
@@ -336,6 +336,8 @@ def disable_jit(disable: bool = True):
336
`cond` functions passed to higher-level primitives like :func:`~jax.lax.scan` and
337
:func:`~jax.lax.while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
338
and any other case where :func:`jit` is used within an API's implementation.
339
+ Note however that even under `disable_jit`, individual primitive operations
340
+ will still be compiled by XLA as in normal eager op-by-op execution.
341
342
Values that have a data dependence on the arguments to a jitted function are
343
traced and abstracted. For example, an abstract value may be a
0 commit comments