Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit d1a1f50

Browse files
committed
Rewrite size input of RandomVariables in JAX backend
1 parent 58dd489 commit d1a1f50

6 files changed

Lines changed: 70 additions & 5 deletions

File tree

aesara/compile/mode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
449449

450450
JAX = Mode(
451451
JAXLinker(),
452-
RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
452+
RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
453453
)
454454
NUMBA = Mode(
455455
NumbaLinker(),

aesara/link/jax/dispatch/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@
1313
import aesara.link.jax.dispatch.elemwise
1414
import aesara.link.jax.dispatch.scan
1515

16+
# Load specialized rewrites
17+
import aesara.link.jax.dispatch.rewriting
18+
1619
# isort: on

aesara/link/jax/dispatch/random.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import aesara.tensor.random.basic as aer
1010
from aesara.link.jax.dispatch.basic import jax_funcify, jax_typify
11+
from aesara.link.jax.dispatch.shape import JAXShapeTuple
1112
from aesara.tensor.shape import Shape, Shape_i
1213

1314

@@ -28,7 +29,7 @@
2829

2930

3031
def assert_size_argument_jax_compatible(node):
31-
"""Assert whether the current node can be compiled.
32+
"""Assert whether the current node can be JIT-compiled by JAX.
3233
3334
JAX can JIT-compile `jax.random` functions when the `size` argument
3435
is a concrete value, i.e. either a constant or the shape of any
@@ -37,7 +38,7 @@ def assert_size_argument_jax_compatible(node):
3738
"""
3839
size = node.inputs[1]
3940
size_op = size.owner.op
40-
if not isinstance(size_op, (Shape, Shape_i)):
41+
if not isinstance(size_op, (Shape, Shape_i, JAXShapeTuple)):
4142
raise NotImplementedError(SIZE_NOT_COMPATIBLE)
4243

4344

aesara/link/jax/dispatch/rewriting.py (new file)

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from aesara.graph.rewriting.basic import node_rewriter, in2out
2+
from aesara.compile import optdb
3+
from aesara.link.jax.dispatch.shape import JAXShapeTuple
4+
from aesara.tensor.basic import MakeVector
5+
from aesara.tensor.elemwise import DimShuffle
6+
from aesara.tensor.random.op import RandomVariable
7+
8+
9+
@node_rewriter([RandomVariable])
def size_parameter_as_tuple(fgraph, node):
    """Replace a `MakeVector` or broadcasting `DimShuffle` `size` input
    of a `RandomVariable` with a `JAXShapeTuple`.

    JAX can only JIT-compile `jax.random` functions when the `size`
    argument is concrete; wrapping the individual size components in a
    `JAXShapeTuple` lets the JAX backend pass them on as a plain tuple
    (see `jax_funcify_JAXShapeTuple`).

    Returns the outputs of the rewritten node, or ``None`` when the
    rewrite does not apply.
    """
    size_arg = node.inputs[1]

    # A constant or graph-input `size` has no owner; nothing to rewrite.
    if size_arg.owner is None:
        return None

    size_op = size_arg.owner.op

    # MakeVector: Aesara converted a tuple/list `size` to a tensor.
    # DimShuffle (() -> ("x",)): Aesara turned a scalar `size` into a
    # 1d vector.
    if isinstance(size_op, MakeVector) or (
        isinstance(size_op, DimShuffle)
        and size_op.input_broadcastable == ()
        and size_op.new_order == ("x",)
    ):
        new_size = JAXShapeTuple()(*size_arg.owner.inputs)
        new_inputs = list(node.inputs)
        new_inputs[1] = new_size
        # Rebuild the node instead of mutating `node.inputs` in place:
        # in-place mutation bypasses the FunctionGraph's bookkeeping and
        # can leave the graph in an inconsistent state.
        return node.op.make_node(*new_inputs).outputs

    return None
28+
29+
30+
# Register the rewrite under the "jax" tag so it only runs when a mode's
# RewriteDatabaseQuery includes "jax" (the JAX `Mode` in
# `aesara/compile/mode.py` does); `in2out` applies it as a simple
# input-to-output pass.
optdb.register(
    "jax_size_parameter_as_tuple",
    in2out(size_parameter_as_tuple),
    "jax",
    position=100
)

aesara/link/jax/dispatch/shape.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,34 @@
11
import jax.numpy as jnp
22

33
from aesara.graph import Constant
4+
from aesara.graph.basic import Apply
5+
from aesara.graph.op import Op
46
from aesara.link.jax.dispatch.basic import jax_funcify
57
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
8+
from aesara.tensor.type import TensorType
9+
10+
11+
class JAXShapeTuple(Op):
    """Dummy `Op` that represents a `size` specified as a tuple.

    This `Op` only exists so the JAX backend can pass the `size`
    components of a `RandomVariable` to `jax.random` as a concrete
    tuple; its JAX implementation is `jax_funcify_JAXShapeTuple`.
    """

    def make_node(self, *inputs):
        # All size components share a dtype; fall back to "int64" for an
        # empty size (scalar draw) so the output type is still valid.
        dtype = inputs[0].type.dtype if inputs else "int64"
        otype = TensorType(dtype, shape=(len(inputs),))
        return Apply(self, inputs, [otype()])

    def perform(self, node, inputs, output_storage):
        # `Op.perform` must write its result into `output_storage`
        # rather than return it (the original `perform(self, *inputs)`
        # signature did not match the `Op` contract and its return value
        # was discarded).  NOTE(review): in JAX mode this `Op` is always
        # replaced by its funcified form, so this path is rarely hit.
        output_storage[0][0] = tuple(inputs)
24+
25+
26+
@jax_funcify.register(JAXShapeTuple)
def jax_funcify_JAXShapeTuple(op, **kwargs):
    """JAX implementation of `JAXShapeTuple`.

    Packs the size components into a plain Python tuple, which JAX
    accepts as a concrete shape argument.
    """

    def shape_tuple_fn(*sizes):
        return tuple(sizes)

    return shape_tuple_fn
632

733

834
@jax_funcify.register(Reshape)

tests/link/jax/test_random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def test_random_concrete_shape():
454454
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
455455

456456

457-
@pytest.mark.xfail(reason="size argument specified as a tuple is a `DimShuffle` node")
457+
# @pytest.mark.xfail(reason="size argument specified as a tuple is a `DimShuffle` node")
458458
def test_random_concrete_shape_subtensor():
459459
rng = shared(np.random.RandomState(123))
460460
x_at = at.dmatrix()
@@ -463,7 +463,7 @@ def test_random_concrete_shape_subtensor():
463463
assert jax_fn(np.ones((2, 3))).shape == (3,)
464464

465465

466-
@pytest.mark.xfail(reason="size argument specified as a tuple is a `MakeVector` node")
466+
# @pytest.mark.xfail(reason="size argument specified as a tuple is a `MakeVector` node")
467467
def test_random_concrete_shape_subtensor_tuple():
468468
rng = shared(np.random.RandomState(123))
469469
x_at = at.dmatrix()

0 commit comments

Comments
 (0)