4 changes: 2 additions & 2 deletions pytensor/tensor/basic.py
@@ -921,7 +921,7 @@ def zeros_like(model, dtype=None, opt=False):
return fill(_model, ret)


def zeros(shape, dtype=None):
def zeros(shape, dtype=None) -> TensorVariable:
"""Create a `TensorVariable` filled with zeros, closer to NumPy's syntax than ``alloc``."""
if not (
isinstance(shape, np.ndarray | Sequence)
@@ -933,7 +933,7 @@ def zeros(shape, dtype=None):
return alloc(np.array(0, dtype=dtype), *shape)


def ones(shape, dtype=None):
def ones(shape, dtype=None) -> TensorVariable:
"""Create a `TensorVariable` filled with ones, closer to NumPy's syntax than ``alloc``."""
if not (
isinstance(shape, np.ndarray | Sequence)
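The return-type annotations above are purely static; the runtime behavior of `zeros`/`ones` is unchanged. A minimal sketch of what type checkers can now infer (standard `pytensor.tensor` alias assumed):

```python
import pytensor.tensor as pt
from pytensor.tensor.variable import TensorVariable

# zeros/ones build symbolic alloc graphs; with the new annotations,
# type checkers see TensorVariable instead of an untyped result.
x = pt.zeros((3, 4), dtype="float64")
y = pt.ones((3, 4))
assert isinstance(x, TensorVariable) and isinstance(y, TensorVariable)
```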
204 changes: 168 additions & 36 deletions pytensor/tensor/signal/conv.py
@@ -1,72 +1,92 @@
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, Literal
from typing import cast as type_cast

import numpy as np
from numpy import convolve as numpy_convolve
from scipy.signal import convolve as scipy_convolve

from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply, Constant
from pytensor.graph.op import Op
from pytensor.link.c.op import COp
from pytensor.scalar import as_scalar
from pytensor.scalar.basic import upcast
from pytensor.tensor.basic import as_tensor_variable, join, zeros
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import maximum, minimum, switch
from pytensor.tensor.type import vector
from pytensor.tensor.pad import pad
from pytensor.tensor.subtensor import flip
from pytensor.tensor.type import tensor
from pytensor.tensor.variable import TensorVariable


if TYPE_CHECKING:
from pytensor.tensor import TensorLike


class Convolve1d(COp):
class AbstractConvolveNd:
__props__ = ()
gufunc_signature = "(n),(k),()->(o)"
ndim: int

@property
def gufunc_signature(self):
data_signature = ",".join([f"n{i}" for i in range(self.ndim)])
kernel_signature = ",".join([f"k{i}" for i in range(self.ndim)])
output_signature = ",".join([f"o{i}" for i in range(self.ndim)])

return f"({data_signature}),({kernel_signature}),()->({output_signature})"

def make_node(self, in1, in2, full_mode):
in1 = as_tensor_variable(in1)
in2 = as_tensor_variable(in2)
full_mode = as_scalar(full_mode)

if not (in1.ndim == 1 and in2.ndim == 1):
raise ValueError("Convolution inputs must be vector (ndim=1)")
ndim = self.ndim
if not (in1.ndim == ndim and in2.ndim == ndim):
raise ValueError(
f"Convolution inputs must have ndim={ndim}, got: in1={in1.ndim}, in2={in2.ndim}"
)
if not full_mode.dtype == "bool":
raise ValueError("Convolution mode must be a boolean type")
raise ValueError("Convolution full_mode flag must be a boolean type")

dtype = upcast(in1.dtype, in2.dtype)
n = in1.type.shape[0]
k = in2.type.shape[0]
match full_mode:
case Constant():
static_mode = "full" if full_mode.data else "valid"
case _:
static_mode = None

if n is None or k is None or static_mode is None:
out_shape = (None,)
elif static_mode == "full":
out_shape = (n + k - 1,)
else: # mode == "valid":
out_shape = (max(n, k) - min(n, k) + 1,)
if static_mode is None:
out_shape = (None,) * ndim
else:
out_shape = []
# TODO: Raise if static shapes are not valid (one input size doesn't dominate the other)
for n, k in zip(in1.type.shape, in2.type.shape):
if n is None or k is None:
out_shape.append(None)
elif static_mode == "full":
out_shape.append(n + k - 1)
else: # static_mode == "valid"
out_shape.append(max(n, k) - min(n, k) + 1)
out_shape = tuple(out_shape)

out = vector(dtype=dtype, shape=out_shape)
return Apply(self, [in1, in2, full_mode], [out])
dtype = upcast(in1.dtype, in2.dtype)

def perform(self, node, inputs, outputs):
# We use numpy_convolve as that's what scipy would use if method="direct" was passed.
# And mode != "same", which this Op doesn't cover anyway.
in1, in2, full_mode = inputs
outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid")
out = tensor(dtype=dtype, shape=out_shape)
return Apply(self, [in1, in2, full_mode], [out])

def infer_shape(self, fgraph, node, shapes):
_, _, full_mode = node.inputs
in1_shape, in2_shape, _ = shapes
n = in1_shape[0]
k = in2_shape[0]
shape_valid = maximum(n, k) - minimum(n, k) + 1
shape_full = n + k - 1
shape = switch(full_mode, shape_full, shape_valid)
return [[shape]]
out_shape = [
switch(full_mode, n + k - 1, maximum(n, k) - minimum(n, k) + 1)
for n, k in zip(in1_shape, in2_shape)
]

return [out_shape]

def connection_pattern(self, node):
return [[True], [True], [False]]
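To make the new signature and static-shape rules concrete, here is a small self-contained sketch mirroring the `gufunc_signature` property and the per-axis output-shape arithmetic above (plain Python; the helper names are illustrative, not part of the PR):

```python
def gufunc_signature(ndim: int) -> str:
    # One named core dimension per axis for data, kernel, and output,
    # mirroring AbstractConvolveNd.gufunc_signature.
    data = ",".join(f"n{i}" for i in range(ndim))
    kernel = ",".join(f"k{i}" for i in range(ndim))
    out = ",".join(f"o{i}" for i in range(ndim))
    return f"({data}),({kernel}),()->({out})"

assert gufunc_signature(1) == "(n0),(k0),()->(o0)"
assert gufunc_signature(2) == "(n0,n1),(k0,k1),()->(o0,o1)"

def out_len(n: int, k: int, full: bool) -> int:
    # Per-axis static output length: "full" -> n + k - 1,
    # "valid" -> max(n, k) - min(n, k) + 1
    return n + k - 1 if full else max(n, k) - min(n, k) + 1

assert out_len(5, 3, full=True) == 7
assert out_len(5, 3, full=False) == 3
```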
@@ -75,22 +95,34 @@ def L_op(self, inputs, outputs, output_grads):
in1, in2, full_mode = inputs
[grad] = output_grads

n = in1.shape[0]
k = in2.shape[0]
n = in1.shape
k = in2.shape
# Note: this assumes the shape of one input dominates the other over all dimensions (which is required for a valid forward)

# If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve
# The expression below is equivalent to ~(full_mode | (k >= n))
full_mode_in1_bar = ~full_mode & (k < n)
full_mode_in1_bar = ~full_mode & (k < n).any()
# If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve
# The expression below is equivalent to ~(full_mode | (n >= k))
full_mode_in2_bar = ~full_mode & (n < k)
full_mode_in2_bar = ~full_mode & (n < k).any()

return [
self(grad, in2[::-1], full_mode_in1_bar),
self(grad, in1[::-1], full_mode_in2_bar),
self(grad, flip(in2), full_mode_in1_bar),
self(grad, flip(in1), full_mode_in2_bar),
DisconnectedType()(),
]


class Convolve1d(AbstractConvolveNd, COp): # type: ignore[misc]
__props__ = ()
ndim = 1

def perform(self, node, inputs, outputs):
# We use numpy_convolve, which is what scipy would dispatch to with method="direct"
# and mode != "same"; this Op doesn't cover "same" mode anyway.
in1, in2, full_mode = inputs
outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid")

def c_code_cache_version(self):
return (2,)

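A quick numpy check of the mode bookkeeping in `L_op` above (an illustrative sketch, not PR code): the gradient with respect to `in1` convolves the output gradient with the flipped kernel, and `full_mode_in1_bar` must be chosen so the result recovers `in1`'s shape:

```python
import numpy as np

n, k = 8, 3  # one input's shape dominates the other, as the forward requires
in1, in2 = np.random.rand(n), np.random.rand(k)

# Forward was "valid": output gradient has length n - k + 1.
# full_mode_in1_bar = ~full_mode & (k < n).any() is True here, so use "full".
g_valid = np.ones(n - k + 1)
assert np.convolve(g_valid, in2[::-1], mode="full").shape == (n,)

# Forward was "full": output gradient has length n + k - 1.
# full_mode_in1_bar is False here, so use "valid".
g_full = np.ones(n + k - 1)
assert np.convolve(g_full, in2[::-1], mode="valid").shape == (n,)
```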
@@ -210,4 +242,104 @@ def convolve1d(
mode = "valid"

full_mode = as_scalar(np.bool_(mode == "full"))
return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
return type_cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))


class Convolve2d(AbstractConvolveNd, Op): # type: ignore[misc]
__props__ = ("method",) # type: ignore[assignment]
ndim = 2

def __init__(self, method: Literal["direct", "fft", "auto"] = "auto"):
self.method = method

def perform(self, node, inputs, outputs):
in1, in2, full_mode = inputs

# TODO: Why is .item() needed?
mode: Literal["full", "valid", "same"] = "full" if full_mode.item() else "valid"
outputs[0][0] = scipy_convolve(in1, in2, mode=mode, method=self.method)


def convolve2d(
in1: "TensorLike",
in2: "TensorLike",
mode: Literal["full", "valid", "same"] = "full",
boundary: Literal["fill", "wrap", "symm"] = "fill",
fillvalue: float | int = 0,
method: Literal["direct", "fft", "auto"] = "auto",
) -> TensorVariable:
"""Convolve two two-dimensional arrays.

Convolve in1 and in2, with the output size determined by the mode argument.

Parameters
----------
in1 : (..., N, M) tensor_like
First input.
in2 : (..., K, L) tensor_like
Second input.
mode : {'full', 'valid', 'same'}, optional
A string indicating the size of the output:
- 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+K-1, M+L-1).
- 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, K) - min(N, K) + 1, max(M, L) - min(M, L) + 1).
- 'same': The output is the same size as in1, centered with respect to the 'full' output.
boundary : {'fill', 'wrap', 'symm'}, optional
A string indicating how to handle boundaries:
- 'fill': Pads the input arrays with fillvalue.
- 'wrap': Circularly wraps the input arrays.
- 'symm': Symmetrically reflects the input arrays.
fillvalue : float or int, optional
The value to use for padding when boundary is 'fill'. Default is 0.
method : {'direct', 'fft', 'auto'}, optional
Computation method to use. 'direct' uses direct convolution, 'fft' uses FFT-based convolution,
and 'auto' lets the implementation choose the best method at runtime.

Returns
-------
out: tensor_variable
The discrete linear convolution of in1 with in2.

"""
in1 = as_tensor_variable(in1)
in2 = as_tensor_variable(in2)
ndim = max(in1.type.ndim, in2.type.ndim)

def _pad_input(input_tensor, pad_width):
if boundary == "fill":
return pad(
input_tensor,
pad_width=pad_width,
mode="constant",
constant_values=fillvalue,
)
if boundary == "wrap":
return pad(input_tensor, pad_width=pad_width, mode="wrap")
if boundary == "symm":
return pad(input_tensor, pad_width=pad_width, mode="symmetric")
raise ValueError(f"Unsupported boundary mode: {boundary}")

if mode == "same":
# Same mode is implemented as "valid" with a padded input.
pad_width = zeros((ndim, 2), dtype="int64")
pad_width = pad_width[-2, 0].set(in2.shape[-2] // 2)
pad_width = pad_width[-2, 1].set((in2.shape[-2] - 1) // 2)
pad_width = pad_width[-1, 0].set(in2.shape[-1] // 2)
pad_width = pad_width[-1, 1].set((in2.shape[-1] - 1) // 2)
in1 = _pad_input(in1, pad_width)
mode = "valid"

if mode != "valid" and (boundary != "fill" or fillvalue != 0):
# We implement the "full" convolution as a "valid" convolution on an appropriately padded input
*_, k, l = in2.shape

pad_width = zeros((ndim, 2), dtype="int64")
pad_width = pad_width[-2, :].set(k - 1)
pad_width = pad_width[-1, :].set(l - 1)
in1 = _pad_input(in1, pad_width)

mode = "valid"

full_mode = as_scalar(np.bool_(mode == "full"))
return type_cast(
TensorVariable, Blockwise(Convolve2d(method=method))(in1, in2, full_mode)
)
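The `mode="same"` branch above rests on a small identity: padding each convolved axis by `k // 2` on the left and `(k - 1) // 2` on the right makes a "valid" convolution return exactly `n` outputs, centered the way "same" requires (for `n >= k`). A quick numpy check of that arithmetic (illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
for n in (7, 8):
    for k in (3, 4):
        x, w = rng.random(n), rng.random(k)
        padded = np.pad(x, (k // 2, (k - 1) // 2))
        out = np.convolve(padded, w, mode="valid")
        # (n + k//2 + (k-1)//2) - k + 1 == n
        assert out.shape == (n,)
        np.testing.assert_allclose(out, np.convolve(x, w, mode="same"))
```

And a hedged usage sketch of the new `convolve2d`, checked against `scipy.signal.convolve2d` (assuming the import path this PR adds; a sanity check, not an exhaustive test):

```python
import numpy as np
from scipy.signal import convolve2d as sp_convolve2d

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.signal.conv import convolve2d  # import path per this PR

rng = np.random.default_rng(0)
img_val = rng.normal(size=(16, 16))
kern_val = rng.normal(size=(3, 3))

img = pt.matrix("img", dtype="float64")
kern = pt.matrix("kern", dtype="float64")
out = convolve2d(img, kern, mode="same", boundary="symm")
fn = pytensor.function([img, kern], out)

np.testing.assert_allclose(
    fn(img_val, kern_val),
    sp_convolve2d(img_val, kern_val, mode="same", boundary="symm"),
)
```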