
Commit 993c2c6

Numba dispatch of ScalarLoop
1 parent 23bbabf commit 993c2c6

File tree

3 files changed: +156 −11 lines

pytensor/link/numba/dispatch/scalar.py
tests/link/numba/test_elemwise.py
tests/link/numba/test_scalar.py

pytensor/link/numba/dispatch/scalar.py

Lines changed: 50 additions & 0 deletions
@@ -15,6 +15,7 @@
 from pytensor.link.utils import (
     get_name_for_object,
 )
+from pytensor.scalar import ScalarLoop
 from pytensor.scalar.basic import (
     Add,
     Cast,
@@ -364,3 +365,52 @@ def softplus(x):
         return numba_basic.direct_cast(value, out_dtype)

     return softplus, scalar_op_cache_key(op, cache_version=1)
+
+
+@register_funcify_and_cache_key(ScalarLoop)
+def numba_funcify_ScalarLoop(op, node, **kwargs):
+    inner_fn, inner_fn_cache_key = numba_funcify_and_cache_key(op.fgraph)
+    if inner_fn_cache_key is None:
+        loop_cache_key = None
+    else:
+        loop_cache_key = sha256(
+            str((type(op), op.is_while, inner_fn_cache_key)).encode()
+        ).hexdigest()
+
+    if op.is_while:
+        n_update = len(op.outputs) - 1
+
+        @numba_basic.numba_njit
+        def while_loop(n_steps, *inputs):
+            carry, constant = inputs[:n_update], inputs[n_update:]
+
+            until = False
+            for i in range(n_steps):
+                outputs = inner_fn(*carry, *constant)
+                carry, until = outputs[:-1], outputs[-1]
+                if until:
+                    break
+
+            return *carry, until
+
+        return while_loop, loop_cache_key
+
+    else:
+        n_update = len(op.outputs)
+
+        @numba_basic.numba_njit
+        def for_loop(n_steps, *inputs):
+            carry, constant = inputs[:n_update], inputs[n_update:]
+
+            if n_steps < 0:
+                raise ValueError("ScalarLoop does not have a termination condition.")
+
+            for i in range(n_steps):
+                carry = inner_fn(*carry, *constant)
+
+            if n_update == 1:
+                return carry[0]
+            else:
+                return carry
+
+        return for_loop, loop_cache_key
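
For context, a minimal sketch of how the new dispatch is exercised end to end, assuming a PyTensor install with the Numba backend available (compiling with mode="NUMBA" routes the op through numba_funcify_ScalarLoop above):

    import numpy as np
    import pytensor.scalar as ps
    from pytensor import function
    from pytensor.scalar import ScalarLoop

    # A for-loop that adds `const` to the carry at every step.
    n_steps = ps.int64("n_steps")
    x0 = ps.float64("x0")
    const = ps.float64("const")
    op = ScalarLoop(init=[x0], constant=[const], update=[x0 + const])

    out = op(n_steps, x0, const)
    fn = function([n_steps, x0, const], out, mode="NUMBA")
    # Ten steps of +0.5 starting from 0.0.
    np.testing.assert_allclose(fn(10, 0.0, 0.5), 5.0)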

tests/link/numba/test_elemwise.py

Lines changed: 34 additions & 10 deletions
@@ -609,18 +609,42 @@ def test_elemwise_multiple_inplace_outs():


 def test_scalar_loop():
-    a = float64("a")
-    scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
+    a_scalar = float64("a")
+    const_scalar = float64("const")
+    scalar_loop = pytensor.scalar.ScalarLoop(
+        init=[a_scalar],
+        update=[a_scalar + a_scalar + const_scalar],
+        constant=[const_scalar],
+    )

-    x = pt.tensor("x", shape=(3,))
-    elemwise_loop = Elemwise(scalar_loop)(3, x)
+    a = pt.tensor("a", shape=(3,))
+    const = pt.tensor("const", shape=(3,))
+    n_steps = 3
+    elemwise_loop = Elemwise(scalar_loop)(n_steps, a, const)

-    with pytest.warns(UserWarning, match="object mode"):
-        compare_numba_and_py(
-            [x],
-            [elemwise_loop],
-            (np.array([1, 2, 3], dtype="float64"),),
-        )
+    compare_numba_and_py(
+        [a, const],
+        [elemwise_loop],
+        [np.array([1, 2, 3], dtype="float64"), np.array([1, 1, 1], dtype="float64")],
+    )
+
+
+def test_gammainc_wrt_k_grad():
+    x = pt.vector("x", dtype="float64")
+    k = pt.vector("k", dtype="float64")
+
+    out = pt.gammainc(k, x)
+    grad_out = grad(out.sum(), k)
+
+    compare_numba_and_py(
+        [x, k],
+        [grad_out],
+        # These values of x and k trigger all the branches in the gradient of gammainc
+        [
+            np.array([0.0, 29.0, 31.0], dtype="float64"),
+            np.array([1.0, 13.0, 11.0], dtype="float64"),
+        ],
+    )


 class TestsBenchmark:
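
As a sanity check on test_scalar_loop above: each step maps the carry a to a + a + const, so after n steps the result is 2**n * a + (2**n - 1) * const (a geometric series). A hand-rolled NumPy reference with the same inputs as the test:

    import numpy as np

    a = np.array([1.0, 2.0, 3.0])
    const = np.array([1.0, 1.0, 1.0])
    n_steps = 3
    # Closed form of the recurrence a <- 2 * a + const applied n_steps times.
    expected = 2**n_steps * a + (2**n_steps - 1) * const
    np.testing.assert_allclose(expected, [15.0, 23.0, 31.0])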

tests/link/numba/test_scalar.py

Lines changed: 72 additions & 1 deletion
@@ -8,7 +8,7 @@
 import pytensor.tensor as pt
 from pytensor import config, function
 from pytensor.graph import Apply
-from pytensor.scalar import UnaryScalarOp
+from pytensor.scalar import ScalarLoop, UnaryScalarOp
 from pytensor.scalar.basic import Composite
 from pytensor.tensor import tensor
 from pytensor.tensor.elemwise import Elemwise
@@ -231,3 +231,74 @@ def test_erf_complex():
         [g],
         [np.array(0.5 + 1j, dtype="complex128")],
     )
+
+
+class TestScalarLoop:
+    def test_scalar_for_loop_single_out(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        const = ps.float64("const")
+        x = x0 + const
+
+        op = ScalarLoop(init=[x0], constant=[const], update=[x])
+        x = op(n_steps, x0, const)
+
+        fn = function([n_steps, x0, const], [x], mode=numba_mode)
+
+        res_x = fn(n_steps=5, x0=0, const=1)
+        np.testing.assert_allclose(res_x, 5)
+
+        res_x = fn(n_steps=5, x0=0, const=2)
+        np.testing.assert_allclose(res_x, 10)
+
+        res_x = fn(n_steps=4, x0=3, const=-1)
+        np.testing.assert_allclose(res_x, -1)
+
+    def test_scalar_for_loop_multiple_outs(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        y0 = ps.int64("y0")
+        const = ps.float64("const")
+        x = x0 + const
+        y = y0 + 1
+
+        op = ScalarLoop(init=[x0, y0], constant=[const], update=[x, y])
+        x, y = op(n_steps, x0, y0, const)
+
+        fn = function([n_steps, x0, y0, const], [x, y], mode=numba_mode)
+
+        res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=1)
+        np.testing.assert_allclose(res_x, 5)
+        np.testing.assert_allclose(res_y, 5)
+
+        res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=2)
+        np.testing.assert_allclose(res_x, 10)
+        np.testing.assert_allclose(res_y, 5)
+
+        res_x, res_y = fn(n_steps=4, x0=3, y0=2, const=-1)
+        np.testing.assert_allclose(res_x, -1)
+        np.testing.assert_allclose(res_y, 6)
+
+    def test_scalar_while_loop(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        x = x0 + 1
+        until = x >= 10
+
+        op = ScalarLoop(init=[x0], update=[x], until=until)
+        fn = function([n_steps, x0], op(n_steps, x0), mode=numba_mode)
+        np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
+        np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
+        np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
+        np.testing.assert_allclose(fn(n_steps=0, x0=1), [1, False])
+
+    def test_loop_with_cython_wrapped_op(self):
+        x = ps.float64("x")
+        op = ScalarLoop(init=[x], update=[ps.psi(x)])
+        out = op(1, x)
+
+        fn = function([x], out, mode=numba_mode)
+        x_test = np.float64(0.5)
+        res = fn(x_test)
+        expected_res = ps.psi(x).eval({x: x_test})
+        np.testing.assert_allclose(res, expected_res)
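
For reference, a plain-Python mirror of the while-loop semantics these tests exercise (illustration only; the compiled kernel is the njit closure in scalar.py, and inner_fn here is a hypothetical stand-in for the compiled inner graph):

    def while_loop_reference(inner_fn, n_steps, carry, constant):
        # Run the inner graph up to n_steps times, stopping early once
        # the trailing `until` output becomes true.
        until = False
        for _ in range(n_steps):
            *carry, until = inner_fn(*carry, *constant)
            if until:
                break
        return (*carry, until)

    # Matches test_scalar_while_loop: x <- x + 1 until x >= 10, from x0 = 0.
    assert while_loop_reference(lambda x: (x + 1, x + 1 >= 10), 20, [0], []) == (10, True)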
