Skip to content

Commit 82611eb

Browse files
committed
document that under disable_jit, individual primitives are still compiled
1 parent e224c3d commit 82611eb

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

jax/_src/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ def disable_jit(disable: bool = True):
336336
`cond` functions passed to higher-level primitives like :func:`~jax.lax.scan` and
337337
:func:`~jax.lax.while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
338338
and any other case where :func:`jit` is used within an API's implementation.
339+
Note however that even under `disable_jit`, individual primitive operations
340+
will still be compiled by XLA as in normal eager op-by-op execution.
339341
340342
Values that have a data dependence on the arguments to a jitted function are
341343
traced and abstracted. For example, an abstract value may be a

0 commit comments

Comments
 (0)