Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit ad7a8b7

Browse files
Smit-createbrandonwillard
authored andcommitted
Implement a Convolve Op
1 parent 5ec04b9 commit ad7a8b7

3 files changed

Lines changed: 201 additions & 2 deletions

File tree

aesara/tensor/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3811,7 +3811,7 @@ def make_node(self, a, choices):
38113811
static_out_shape = ()
38123812
for s in out_shape:
38133813
try:
3814-
s_val = aesara.get_scalar_constant_value(s)
3814+
s_val = get_scalar_constant_value(s)
38153815
except (NotScalarConstantError, AttributeError):
38163816
s_val = None
38173817

aesara/tensor/math.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import builtins
22
import warnings
3-
from typing import TYPE_CHECKING, List, Optional, Tuple
3+
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple
44

55
import numpy as np
66

@@ -22,10 +22,12 @@
2222
cast,
2323
concatenate,
2424
constant,
25+
get_scalar_constant_value,
2526
stack,
2627
switch,
2728
)
2829
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
30+
from aesara.tensor.exceptions import NotScalarConstantError
2931
from aesara.tensor.shape import shape, specify_broadcastable
3032
from aesara.tensor.type import (
3133
DenseTensorType,
@@ -46,6 +48,8 @@
4648
if TYPE_CHECKING:
4749
from numpy.typing import ArrayLike, DTypeLike
4850

51+
from aesara.tensor.var import TensorVariable
52+
4953
# We capture the builtins that we are going to replace to follow the numpy API
5054
_abs = builtins.abs
5155

@@ -2998,6 +3002,108 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
29983002
return MatMul(dtype=dtype)(x1, x2)
29993003

30003004

3005+
class Convolve(Op):
3006+
__props__ = ("mode",)
3007+
3008+
def __init__(self, mode="full"):
3009+
self.mode = mode
3010+
3011+
@classmethod
3012+
def _get_output_shape(cls, a, v, shapes, mode, validate=False):
3013+
a_shape, v_shape = shapes
3014+
from aesara.tensor.math import maximum, minimum
3015+
3016+
if a.ndim == 1 and v.ndim == 1:
3017+
m, n = a_shape[0], v_shape[0]
3018+
if n is None or m is None:
3019+
return (None,)
3020+
if mode == "full":
3021+
return (m + n - 1,)
3022+
elif mode == "same":
3023+
return (maximum(m, n),)
3024+
elif mode == "valid":
3025+
return (maximum(m, n) - minimum(m, n) + 1,)
3026+
if validate:
3027+
raise ValueError("Invalid mode - must be full, valid or same")
3028+
return ()
3029+
else:
3030+
if validate:
3031+
raise ValueError("`a` and `v` must be 1-dim.")
3032+
return ()
3033+
3034+
def make_node(self, a, v):
3035+
a = as_tensor_variable(a)
3036+
v = as_tensor_variable(v)
3037+
3038+
if a.ndim != 1 or v.ndim != 1:
3039+
raise ValueError("inputs to `convolve` must be 1-dim.")
3040+
3041+
out_shape = self._get_output_shape(
3042+
a, v, (a.type.shape, v.type.shape), self.mode, validate=True
3043+
)
3044+
3045+
static_out_shape = ()
3046+
for s in out_shape:
3047+
try:
3048+
s_val = get_scalar_constant_value(s)
3049+
except (NotScalarConstantError, AttributeError):
3050+
s_val = None
3051+
3052+
if s_val:
3053+
static_out_shape += (s_val,)
3054+
else:
3055+
static_out_shape += (None,)
3056+
3057+
out = TensorType(
3058+
aes.upcast(a.type.dtype, v.type.dtype), shape=static_out_shape
3059+
)()
3060+
return Apply(self, [a, v], [out])
3061+
3062+
def perform(self, node, inputs, outputs):
3063+
a, v = inputs
3064+
outputs[0][0] = np.convolve(a, v, mode=self.mode)
3065+
3066+
def infer_shape(self, fgraph, node, shapes):
3067+
a, v = node.inputs
3068+
return [self._get_output_shape(a, v, shapes, self.mode)]
3069+
3070+
3071+
def convolve(
3072+
a: "ArrayLike", v: "ArrayLike", mode: Literal["full", "same", "valid"] = "full"
3073+
) -> "TensorVariable":
3074+
"""Compute the discrete, linear convolution of two one-dimensional sequences.
3075+
3076+
Parameters
3077+
----------
3078+
a, v
3079+
Input arrays, both should be one dimensional.
3080+
mode
3081+
'full':
3082+
By default, mode is 'full'. This returns the convolution
3083+
at each point of overlap.
3084+
3085+
'same':
3086+
Mode 'same'. Boundary effects are still visible.
3087+
3088+
'valid':
3089+
The convolution product is only given for points
3090+
where the signals overlap completely.
3091+
Values outside the signal boundary have no effect.
3092+
3093+
Returns
3094+
-------
3095+
out
3096+
Discrete, linear convolution of a and v.
3097+
3098+
Raises
3099+
------
3100+
ValueError
3101+
If the a and v are not one-dimensional.
3102+
3103+
"""
3104+
return Convolve(mode=mode)(a, v)
3105+
3106+
30013107
__all__ = [
30023108
"max_and_argmax",
30033109
"max",
@@ -3126,6 +3232,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
31263232
"logsumexp",
31273233
"hyp2f1",
31283234
"hyp2f1_der",
3235+
"convolve",
31293236
]
31303237

31313238
DEPRECATED_NAMES: List[Tuple[str, str, object]] = [

tests/tensor/test_math.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from aesara.tensor.elemwise import CAReduce, Elemwise
3434
from aesara.tensor.math import (
3535
Argmax,
36+
Convolve,
3637
Dot,
3738
MatMul,
3839
MaxAndArgmax,
@@ -59,6 +60,7 @@
5960
clip,
6061
complex_from_polar,
6162
conj,
63+
convolve,
6264
cos,
6365
cosh,
6466
cov,
@@ -3576,3 +3578,93 @@ def test_deprecations():
35763578

35773579
with pytest.deprecated_call():
35783580
from aesara.tensor.math import sqr # noqa: F401 F811
3581+
3582+
3583+
class TestConvolve(utt.InferShapeTester):
3584+
@pytest.mark.parametrize(
3585+
"a, v, mode",
3586+
[
3587+
(
3588+
np.arange(3).astype(config.floatX),
3589+
np.arange(6).astype(config.floatX),
3590+
"full",
3591+
),
3592+
(
3593+
np.arange(5).astype(config.floatX),
3594+
np.arange(3).astype(config.floatX),
3595+
"valid",
3596+
),
3597+
(
3598+
np.arange(8).astype(config.floatX),
3599+
np.arange(9).astype(config.floatX),
3600+
"same",
3601+
),
3602+
(np.arange(3).astype(int), np.arange(6).astype(int), "full"),
3603+
(
3604+
np.random.normal(size=4).astype(config.floatX),
3605+
np.random.uniform(size=8).astype(int),
3606+
"valid",
3607+
),
3608+
],
3609+
)
3610+
def test_op(self, a, v, mode):
3611+
aesara_sol = convolve(a, v, mode=mode).eval()
3612+
numpy_sol = np.convolve(a, v, mode=mode)
3613+
assert np.allclose(numpy_sol, aesara_sol)
3614+
3615+
def test_scalar_error(self):
3616+
with pytest.raises(ValueError, match="must be 1-dim"):
3617+
convolve(4, [4, 1])
3618+
3619+
@pytest.mark.parametrize(
3620+
"a_shape, v_shape, mode, error_regex",
3621+
[
3622+
((2,), (3, 1), "full", "must be 1-dim"),
3623+
((2,), (3,), "val", "must be full, valid or same"),
3624+
],
3625+
)
3626+
def test_get_output_shape_error(self, a_shape, v_shape, mode, error_regex):
3627+
a = tensor(dtype=np.float64, shape=a_shape)
3628+
v = tensor(dtype=np.float64, shape=v_shape)
3629+
3630+
with pytest.raises(ValueError, match=error_regex):
3631+
Convolve._get_output_shape(a, v, (a_shape, v_shape), mode, validate=True)
3632+
assert (
3633+
Convolve._get_output_shape(a, v, (a_shape, v_shape), mode, validate=False)
3634+
== ()
3635+
)
3636+
3637+
@pytest.mark.parametrize(
3638+
"a_shape, v_shape, a1_shape, v1_shape",
3639+
[
3640+
((5,), (5,), (5,), (5,)),
3641+
((1,), (5,), (1,), (5,)),
3642+
((1,), (1,), (1,), (1,)),
3643+
((3,), (1,), (3,), (1,)),
3644+
((None,), (3,), (4,), (3,)),
3645+
((None,), (None,), (4,), (2,)),
3646+
((2,), (None,), (2,), (3,)),
3647+
],
3648+
)
3649+
@pytest.mark.parametrize(
3650+
"mode",
3651+
[
3652+
"full",
3653+
"valid",
3654+
"same",
3655+
],
3656+
)
3657+
def test_infer_shape(self, a_shape, v_shape, a1_shape, v1_shape, mode):
3658+
a = tensor(dtype=config.floatX, shape=a_shape)
3659+
v = tensor(dtype=config.floatX, shape=v_shape)
3660+
3661+
rng = np.random.default_rng(utt.fetch_seed())
3662+
a1 = rng.random(a1_shape).astype(config.floatX)
3663+
v1 = rng.random(v1_shape).astype(config.floatX)
3664+
3665+
self._compile_and_check(
3666+
[a, v],
3667+
[convolve(a, v, mode)],
3668+
[a1, v1],
3669+
Convolve,
3670+
)

0 commit comments

Comments
 (0)