Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions aesara/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@


if TYPE_CHECKING:
from aesara.graph.basic import Variable
from aesara.graph.op import StorageMapType


Expand Down Expand Up @@ -79,13 +80,21 @@ def numba_vectorize(*args, **kwargs):


@singledispatch
def get_numba_type(aesara_type: Type, **kwargs) -> numba.types.Type:
r"""Create a Numba type object for a :class:`Type`."""
def get_numba_type(aesara_type: Type, var: "Variable", **kwargs) -> numba.types.Type:
r"""Create a Numba type object for a :class:`Type`.

Parameters
----------
aesara_type
The :class:`Type` to convert.
var
The :class:`Variable` corresponding to `aesara_type`.
"""
return numba.types.pyobject


@get_numba_type.register(ScalarType)
def get_numba_type_ScalarType(aesara_type, **kwargs):
def get_numba_type_ScalarType(aesara_type, var, **kwargs):
dtype = np.dtype(aesara_type.dtype)
numba_dtype = numba.from_dtype(dtype)
return numba_dtype
Expand All @@ -94,6 +103,7 @@ def get_numba_type_ScalarType(aesara_type, **kwargs):
@get_numba_type.register(TensorType)
def get_numba_type_TensorType(
aesara_type,
var: "Variable",
layout: str = "A",
force_scalar: bool = False,
reduce_to_scalar: bool = False,
Expand All @@ -103,6 +113,8 @@ def get_numba_type_TensorType(
----------
aesara_type
The :class:`Type` to convert.
var
The :class:`Variable` corresponding to `aesara_type`.
layout
The :class:`numpy.ndarray` layout to use.
force_scalar
Expand All @@ -114,7 +126,10 @@ def get_numba_type_TensorType(
numba_dtype = numba.from_dtype(dtype)
if force_scalar or (reduce_to_scalar and getattr(aesara_type, "ndim", None) == 0):
return numba_dtype
return numba.types.Array(numba_dtype, aesara_type.ndim, layout)

readonly = getattr(var.tag, "indestructible", False)

return numba.types.Array(numba_dtype, aesara_type.ndim, layout, readonly=readonly)


def create_numba_signature(
Expand All @@ -123,11 +138,11 @@ def create_numba_signature(
"""Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
input_types = []
for inp in node_or_fgraph.inputs:
input_types.append(get_numba_type(inp.type, **kwargs))
input_types.append(get_numba_type(inp.type, inp, **kwargs))

output_types = []
for out in node_or_fgraph.outputs:
output_types.append(get_numba_type(out.type, **kwargs))
output_types.append(get_numba_type(out.type, inp, **kwargs))

if isinstance(node_or_fgraph, FunctionGraph):
return numba.types.Tuple(output_types)(*input_types)
Expand Down Expand Up @@ -379,9 +394,9 @@ def numba_funcify_perform(op, node, storage_map=None, **kwargs) -> Callable:
n_outputs = len(node.outputs)

if n_outputs > 1:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
ret_sig = numba.types.Tuple([get_numba_type(o.type, o) for o in node.outputs])
else:
ret_sig = get_numba_type(node.outputs[0].type)
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])

output_types = tuple(out.type for out in node.outputs)
params = node.run_params()
Expand Down Expand Up @@ -821,7 +836,7 @@ def cholesky(a):
UserWarning,
)

ret_sig = get_numba_type(node.outputs[0].type)
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])

@numba_njit
def cholesky(a):
Expand Down Expand Up @@ -850,7 +865,7 @@ def numba_funcify_Solve(op, node, **kwargs):
UserWarning,
)

ret_sig = get_numba_type(node.outputs[0].type)
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])

@numba_njit
def solve(a, b):
Expand Down
10 changes: 6 additions & 4 deletions aesara/link/numba/dispatch/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
UserWarning,
)

ret_sig = get_numba_type(node.outputs[0].type)
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])

@numba_basic.numba_njit
def repeatop(x, repeats):
Expand Down Expand Up @@ -243,9 +243,11 @@ def unique(x):
)

if returns_multi:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
ret_sig = numba.types.Tuple(
[get_numba_type(o.type, o) for o in node.outputs]
)
else:
ret_sig = get_numba_type(node.outputs[0].type)
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])

@numba_basic.numba_njit
def unique(x):
Expand Down Expand Up @@ -308,7 +310,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
UserWarning,
)

ret_sig = get_numba_type(node.outputs[0].type)
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])

@numba_basic.numba_njit
def searchsorted(a, v, sorter):
Expand Down
13 changes: 9 additions & 4 deletions aesara/link/numba/dispatch/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def numba_funcify_SVD(op, node, **kwargs):
UserWarning,
)

ret_sig = get_numba_type(node.outputs[0].type)
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])

@numba_basic.numba_njit
def svd(x):
Expand Down Expand Up @@ -101,7 +101,10 @@ def numba_funcify_Eigh(op, node, **kwargs):

out_dtypes = tuple(o.type.numpy_dtype for o in node.outputs)
ret_sig = numba.types.Tuple(
[get_numba_type(node.outputs[0].type), get_numba_type(node.outputs[1].type)]
[
get_numba_type(node.outputs[0].type, node.outputs[0]),
get_numba_type(node.outputs[1].type, node.outputs[1]),
]
)

@numba_basic.numba_njit
Expand Down Expand Up @@ -173,9 +176,11 @@ def numba_funcify_QRFull(op, node, **kwargs):
)

if len(node.outputs) > 1:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
ret_sig = numba.types.Tuple(
[get_numba_type(o.type, o) for o in node.outputs]
)
else:
ret_sig = get_numba_type(node.outputs[0].type)
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])

@numba_basic.numba_njit
def qr_full(x):
Expand Down
2 changes: 1 addition & 1 deletion aesara/link/numba/dispatch/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def copy(inst):


@get_numba_type.register(SparseTensorType)
def get_numba_type_SparseType(aesara_type, **kwargs):
def get_numba_type_SparseType(aesara_type, var, **kwargs):
dtype = from_dtype(np.dtype(aesara_type.dtype))

if aesara_type.format == "csr":
Expand Down
2 changes: 1 addition & 1 deletion aesara/scan/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3427,7 +3427,7 @@ def profile_printer(
)


@op_debug_information.register(Scan)
@op_debug_information.register(Scan) # type: ignore[has-type]
def _op_debug_information_Scan(op, node):
from typing import Sequence

Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,9 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False

[mypy-aesara.scan.op]
warn_unused_ignores = False

[mypy-aesara.link.numba.dispatch.extra_ops]
ignore_errors = True
check_untyped_defs = False
Expand Down
14 changes: 11 additions & 3 deletions tests/link/numba/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def assert_fn(x, y):


@pytest.mark.parametrize(
"v, expected, force_scalar",
"typ, expected, force_scalar",
[
(MyType(), numba.types.pyobject, False),
(
Expand All @@ -267,11 +267,19 @@ def assert_fn(x, y):
(at.dmatrix, numba.types.float64, True),
],
)
def test_get_numba_type(v, expected, force_scalar):
res = numba_basic.get_numba_type(v, force_scalar=force_scalar)
def test_get_numba_type(typ, expected, force_scalar):
res = numba_basic.get_numba_type(typ, typ(), force_scalar=force_scalar)
assert res == expected


def test_get_numba_type_readonly():
typ = at.dmatrix
var = typ()
var.tag.indestructible = True
res = numba_basic.get_numba_type(typ, var)
assert not res.mutable


@pytest.mark.parametrize(
"v, expected, force_scalar",
[
Expand Down
10 changes: 10 additions & 0 deletions tests/link/numba/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,3 +507,13 @@ def test_MaxAndArgmax(x, axes, exc):
if not isinstance(i, (SharedVariable, Constant))
],
)


def test_sum_broadcast_to():
"""Make sure that we handle the writability of `BroadcastTo` results correctly."""

x = at.vector("x")
out = at.broadcast_to(x, (2, 2)).sum()

x_val = np.array([1, 2], dtype=config.floatX)
compare_numba_and_py(((x,), (out,)), [x_val])