diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index c88f78c92..301ab7341 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -61,7 +61,6 @@ jobs:
displayName: 'Build sdist'
- script: |
- python -m pip install mypy==0.910
python -m mypy thinc
displayName: 'Run mypy'
diff --git a/requirements.txt b/requirements.txt
index 1d8b9491f..4a31fc017 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,7 @@ wasabi>=0.8.1,<1.1.0
catalogue>=2.0.4,<2.1.0
ml_datasets>=0.2.0,<0.3.0
# Third-party dependencies
-pydantic>=1.7.4,!=1.8,!=1.8.1,<1.9.0
+pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0
numpy>=1.15.0
# Backports of modern Python features
dataclasses>=0.6,<1.0; python_version < "3.7"
@@ -22,8 +22,7 @@ pytest-cov>=2.7.0,<2.8.0
coverage>=5.0.0,<6.0.0
mock>=2.0.0,<3.0.0
flake8>=3.5.0,<3.6.0
-# restricting mypy until faster 3.10 wheels are available
-mypy>=0.901,<0.920; python_version < "3.10"
+mypy>=0.901,<0.960
types-mock>=0.1.1
types-contextvars>=0.1.2; python_version < "3.7"
types-dataclasses>=0.1.3; python_version < "3.7"
diff --git a/setup.cfg b/setup.cfg
index 59b69419a..4e535d899 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -31,7 +31,7 @@ python_requires = >=3.6
setup_requires =
cython>=0.25,<3.0
numpy>=1.15.0
- # We also need our Cython packages here to compile against
+ # We also need our Cython packages here to compile against
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=1.0.2,<1.1.0
@@ -48,7 +48,7 @@ install_requires =
# Third-party dependencies
setuptools
numpy>=1.15.0
- pydantic>=1.7.4,!=1.8,!=1.8.1,<1.9.0
+ pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0
# Backports of modern Python features
dataclasses>=0.6,<1.0; python_version < "3.7"
typing_extensions>=3.7.4.1,<4.0.0.0; python_version < "3.8"
diff --git a/thinc/backends/ops.py b/thinc/backends/ops.py
index 1d1a374f5..315b0b0bf 100644
--- a/thinc/backends/ops.py
+++ b/thinc/backends/ops.py
@@ -5,9 +5,9 @@
import numpy
import itertools
-from .. import registry
from ..types import Xp, Shape, DTypes, DTypesInt, DTypesFloat, List2d, ArrayXd
-from ..types import Array3d, Floats1d, Floats2d, Floats3d, Floats4d
+from ..types import Floats1d, Floats2d, Floats3d, Floats4d
+from ..types import Array1d, Array2d, Array3d, Array4d, ListXd
from ..types import FloatsXd, Ints1d, Ints2d, Ints3d, Ints4d, IntsXd, _Floats
from ..types import DeviceTypes, Generator, Padded, Batchable, SizedGenerator
from ..util import get_array_module, is_xp_array, to_numpy
@@ -135,13 +135,11 @@ def _get_batch(self, sequence, indices):
if isinstance(sequence, list):
subseq = [sequence[i] for i in indices]
elif isinstance(sequence, tuple):
- subseq = tuple(sequence[i] for i in indices) # type: ignore
+ subseq = tuple(sequence[i] for i in indices)
else:
- subseq = sequence[indices] # type: ignore
+ subseq = sequence[indices]
if is_xp_array(subseq):
- subseq = self.as_contig(
- cast(ArrayXd, self.xp.asarray(subseq))
- ) # type: ignore
+ subseq = self.as_contig(self.xp.asarray(subseq))
return subseq
def _get_batch_sizes(self, length: int, sizes: Iterator[int]):
@@ -225,13 +223,65 @@ def affine(self, X: Floats2d, W: Floats2d, b: Floats1d) -> Floats2d:
Y += b
return Y
+ @overload
def flatten(
self,
- X: Sequence[ArrayT],
+ X: List[Floats2d],
dtype: Optional[DTypes] = None,
pad: int = 0,
ndim_if_empty: int = 2,
- ) -> ArrayT:
+ ) -> Floats2d:
+ ...
+
+ @overload
+ def flatten(
+ self,
+ X: List[Ints1d],
+ dtype: Optional[DTypes] = None,
+ pad: int = 0,
+ ndim_if_empty: int = 2,
+ ) -> Ints1d:
+ ...
+
+ @overload
+ def flatten(
+ self,
+ X: List2d,
+ dtype: Optional[DTypes] = None,
+ pad: int = 0,
+ ndim_if_empty: int = 2,
+ ) -> Array2d:
+ ...
+
+ # further specific typed signatures can be added as necessary
+
+ @overload
+ def flatten(
+ self,
+ X: ListXd,
+ dtype: Optional[DTypes] = None,
+ pad: int = 0,
+ ndim_if_empty: int = 2,
+ ) -> ArrayXd:
+ ...
+
+ @overload
+ def flatten(
+ self,
+ X: Sequence[ArrayXd],
+ dtype: Optional[DTypes] = None,
+ pad: int = 0,
+ ndim_if_empty: int = 2,
+ ) -> ArrayXd:
+ ...
+
+ def flatten(
+ self,
+ X: Sequence[ArrayXd],
+ dtype: Optional[DTypes] = None,
+ pad: int = 0,
+ ndim_if_empty: int = 2,
+ ) -> ArrayXd:
"""Flatten a list of arrays into one large array."""
if X is None or len(X) == 0:
return self.alloc((0,) * ndim_if_empty, dtype=dtype or "f")
@@ -252,7 +302,25 @@ def flatten(
result = xp.asarray(result, dtype=dtype)
return result
+ @overload
def unflatten(self, X: Floats2d, lengths: Ints1d, pad: int = 0) -> List[Floats2d]:
+ ...
+
+ @overload
+ def unflatten(self, X: Ints1d, lengths: Ints1d, pad: int = 0) -> List[Ints1d]:
+ ...
+
+ @overload
+ def unflatten(self, X: Array2d, lengths: Ints1d, pad: int = 0) -> List2d:
+ ...
+
+ # further specific typed signatures can be added as necessary
+
+ @overload
+ def unflatten(self, X: ArrayXd, lengths: Ints1d, pad: int = 0) -> ListXd:
+ ...
+
+ def unflatten(self, X: ArrayXd, lengths: Ints1d, pad: int = 0) -> ListXd:
"""The reverse/backward operation of the `flatten` function: unflatten
a large array into a list of arrays according to the given lengths.
"""
@@ -302,7 +370,7 @@ def pad( # noqa: F811
output: Array3d = self.alloc(final_shape, dtype=seqs[0].dtype)
for i, arr in enumerate(seqs):
# It's difficult to convince this that the dtypes will match.
- output[i, : arr.shape[0]] = arr # type: ignore
+ output[i, : arr.shape[0]] = arr # type: ignore[assignment, call-overload]
return output
def unpad(self, padded: Array3d, lengths: List[int]) -> List2d:
@@ -314,14 +382,14 @@ def unpad(self, padded: Array3d, lengths: List[int]) -> List2d:
output.append(padded[i, :length])
return cast(List2d, output)
- def list2padded(self, seqs: List[Floats2d]) -> Padded:
+ def list2padded(self, seqs: List2d) -> Padded:
"""Pack a sequence of 2d arrays into a Padded datatype."""
if not seqs:
return Padded(
self.alloc3f(0, 0, 0), self.alloc1i(0), self.alloc1i(0), self.alloc1i(0)
)
elif len(seqs) == 1:
- data = self.reshape3f(seqs[0], seqs[0].shape[0], 1, seqs[0].shape[1])
+ data = self.reshape3(seqs[0], seqs[0].shape[0], 1, seqs[0].shape[1])
size_at_t = self.asarray1i([1] * data.shape[0])
lengths = self.asarray1i([data.shape[0]])
indices = self.asarray1i([0])
@@ -336,8 +404,8 @@ def list2padded(self, seqs: List[Floats2d]) -> Padded:
# Reorder the sequences, by length. This looks the same in either
# direction: you're swapping elements between their original and sorted
# position.
- seqs = [seqs[i] for i in indices_]
- arr: Floats3d = self.pad(seqs)
+ seqs = cast(List2d, [seqs[i] for i in indices_])
+ arr: Array3d = self.pad(seqs)
assert arr.shape == (nB, nS, nO), (nB, nS, nO)
arr = self.as_contig(arr.transpose((1, 0, 2)))
assert arr.shape == (nS, nB, nO)
@@ -350,7 +418,7 @@ def list2padded(self, seqs: List[Floats2d]) -> Padded:
batch_size_at_t_[t] = current_size
assert sum(lengths_) == sum(batch_size_at_t_)
return Padded(
- cast(Floats3d, arr),
+ arr,
self.asarray1i(batch_size_at_t_),
self.asarray1i(lengths_),
self.asarray1i(indices_),
@@ -361,7 +429,7 @@ def padded2list(self, padded: Padded) -> List2d:
data = padded.data
indices = to_numpy(padded.indices)
lengths = to_numpy(padded.lengths)
- unpadded: List[Optional[Floats2d]] = [None] * len(lengths)
+ unpadded: List[Optional[Array2d]] = [None] * len(lengths)
# Transpose from (length, batch, data) to (batch, length, data)
data = self.as_contig(data.transpose((1, 0, 2)))
for i in range(data.shape[0]):
@@ -500,6 +568,18 @@ def alloc(
else:
return self.xp.empty(shape, dtype=dtype)
+ def reshape1(self, array: ArrayXd, d0: int) -> Array1d:
+ return cast(Array1d, self.reshape(array, (d0,)))
+
+ def reshape2(self, array: ArrayXd, d0: int, d1: int) -> Array2d:
+ return cast(Array2d, self.reshape(array, (d0, d1)))
+
+ def reshape3(self, array: ArrayXd, d0: int, d1: int, d2: int) -> Array3d:
+ return cast(Array3d, self.reshape(array, (d0, d1, d2)))
+
+ def reshape4(self, array: ArrayXd, d0: int, d1: int, d2: int, d3: int) -> Array4d:
+ return cast(Array4d, self.reshape(array, (d0, d1, d2, d3)))
+
def reshape1f(self, array: FloatsXd, d0: int) -> Floats1d:
return cast(Floats1d, self.reshape(array, (d0,)))
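# Illustrative sketch: the dtype-agnostic reshape1..reshape4 helpers complement
# the existing reshape2f-style methods, so a generic array can be reshaped
# without an extra cast (assumes NumpyOps and alloc2f from thinc.api).
from thinc.api import NumpyOps

ops = NumpyOps()
arr2d = ops.alloc2f(4, 3)
arr3d = ops.reshape3(arr2d, arr2d.shape[0], 1, arr2d.shape[1])  # typed as Array3d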
@@ -619,7 +699,7 @@ def asarray(
return self.xp.asarray(data, dtype=dtype)
elif hasattr(data, "numpy"):
# Handles PyTorch Tensor
- return data.numpy() # type: ignore
+ return data.numpy() # type: ignore[union-attr]
elif dtype is not None:
return self.xp.array(data, dtype=dtype)
else:
@@ -641,8 +721,8 @@ def sigmoid(self, X: FloatsType, *, inplace: bool = False) -> FloatsType:
if inplace:
self.xp.exp(-X, out=X)
- X += 1.0 # type: ignore
- X **= -1.0 # type: ignore
+ X += 1.0 # type: ignore[assignment]
+ X **= -1.0 # type: ignore[assignment]
return cast(FloatsType, X)
else:
return cast(FloatsType, 1.0 / (1.0 + self.xp.exp(-X)))
@@ -786,10 +866,10 @@ def clipped_linear(
inplace: bool = False,
) -> FloatsType:
if inplace:
- X *= slope # type: ignore
- X += offset # type: ignore
+ X *= slope # type: ignore[assignment]
+ X += offset # type: ignore[assignment]
return cast(FloatsType, self.xp.clip(X, min_val, max_val, out=X))
- out = X * slope + offset # type: ignore
+ out = X * slope + offset # type: ignore[assignment]
return cast(FloatsType, self.xp.clip(out, min_val, max_val))
def backprop_clipped_linear(
@@ -840,27 +920,27 @@ def backprop_hard_tanh(
def swish(self, X: FloatsType, inplace: bool = False) -> FloatsType:
if inplace:
- X *= self.sigmoid(X) # type: ignore
+ X *= self.sigmoid(X) # type: ignore[operator, assignment]
return cast(FloatsType, X)
- out = X * self.sigmoid(X) # type: ignore
+ out = X * self.sigmoid(X) # type: ignore[operator]
return cast(FloatsType, out)
def backprop_swish(
self, dY: FloatsType, X: FloatsType, Y: FloatsType, inplace: bool = False
) -> FloatsType:
- Y = Y + self.sigmoid(X) * (1 - Y) # type: ignore
+ Y = Y + self.sigmoid(X) * (1 - Y) # type: ignore[operator]
if inplace:
- dY *= Y # type: ignore
+ dY *= Y # type: ignore[operator, assignment]
return cast(FloatsType, dY)
- out = dY * Y # type: ignore
+ out = dY * Y # type: ignore[operator]
return cast(FloatsType, out)
# Following https://www.scitepress.org/Papers/2019/74696/74696.pdf
def hard_swish(self, X: FloatsType, inplace: bool = False) -> FloatsType:
if inplace:
- X *= self.hard_sigmoid(X) # type: ignore
+ X *= self.hard_sigmoid(X) # type: ignore[operator, assignment]
return cast(FloatsType, X)
- out = X * self.hard_sigmoid(X) # type: ignore
+ out = X * self.hard_sigmoid(X) # type: ignore[operator]
return cast(FloatsType, out)
def backprop_hard_swish(
@@ -927,7 +1007,7 @@ def gelu_approx(self, X: FloatsType, inplace: bool = False) -> FloatsType:
else:
Y = self.xp.array(X)
Y *= tmp
- return cast(FloatsType, Y)
+ return Y
def backprop_gelu_approx(
self, dY: FloatsType, X: FloatsType, inplace: bool = False
@@ -949,15 +1029,15 @@ def gelu(self, X: FloatsType, inplace: bool = False) -> FloatsType:
# GELU(x) = x · Φ(x)
cdf = gaussian_cdf(self, X)
if inplace:
- X *= cdf # type: ignore
+ X *= cdf # type: ignore[operator, assignment]
return X
- return X * cdf # type: ignore
+ return X * cdf # type: ignore[operator, return-value]
def backprop_gelu(
self, dY: FloatsType, X: FloatsType, inplace: bool = False
) -> FloatsType:
# GELU'(x) = Φ(x) + x · PDF(x)
- dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X) # type: ignore
+ dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X) # type: ignore[operator]
if inplace:
dY *= dX
return dY
@@ -1239,8 +1319,8 @@ def lstm_forward_training(
for d in range(dirs):
# The inits are shaped (depth, dirs, nO). We add the internal dimension
# to make them set correctly.
- Yt2 = h_init[i, d].reshape((1, nO)) # type: ignore
- Ct2 = c_init[i, d].reshape((1, nO)) # type: ignore
+ Yt2 = h_init[i, d].reshape((1, nO)) # type: ignore[assignment]
+ Ct2 = c_init[i, d].reshape((1, nO)) # type: ignore[assignment]
layer_params, params_i = _split_weights(params, i, nO, nI, params_i)
Wx, Wh, bias = _transpose_weights(layer_params)
G[i, d] += xp.dot(X, Wx.T)
diff --git a/thinc/config.py b/thinc/config.py
index 167f8fd97..837f91b76 100644
--- a/thinc/config.py
+++ b/thinc/config.py
@@ -1,4 +1,4 @@
-from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type
+from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping
from typing import Iterable, Sequence, cast
from types import GeneratorType
from dataclasses import dataclass
@@ -550,7 +550,7 @@ def __init__(
self,
*,
config: Optional[Union[Config, Dict[str, Dict[str, Any]], str]] = None,
- errors: Iterable[Dict[str, Any]] = tuple(),
+ errors: Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]] = tuple(),
title: Optional[str] = "Config validation error",
desc: Optional[str] = None,
parent: Optional[str] = None,
@@ -560,9 +560,10 @@ def __init__(
config (Union[Config, Dict[str, Dict[str, Any]], str]): The
config the validation error refers to.
- errors (Iterable[Dict[str, Any]]): A list of errors as dicts with keys
- "loc" (list of strings describing the path of the value), "msg"
- (validation message to show) and optional "type" (mostly internals).
+ errors (Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]]):
+ A list of errors as dicts with keys "loc" (list of strings
+ describing the path of the value), "msg" (validation message
+ to show) and optional "type" (mostly internals).
Same format as produced by pydantic's validation error (e.errors()).
title (str): The error title.
desc (str): Optional error description, displayed below the title.
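# Illustrative sketch: the widened `errors` parameter accepts the mappings that
# pydantic's `e.errors()` produces; the config and error values below are
# hypothetical.
from thinc.config import ConfigValidationError

errors = [{"loc": ("optimizer", "learn_rate"), "msg": "field required"}]
err = ConfigValidationError(config={"optimizer": {}}, errors=errors)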
diff --git a/thinc/layers/array_getitem.py b/thinc/layers/array_getitem.py
index 87a62e9b8..17ffcb7ee 100644
--- a/thinc/layers/array_getitem.py
+++ b/thinc/layers/array_getitem.py
@@ -1,13 +1,14 @@
-from typing import Union, Sequence, Tuple
+from typing import Union, Sequence, Tuple, TypeVar
from ..types import ArrayXd, FloatsXd, IntsXd
from ..model import Model
AxisIndex = Union[int, slice, Sequence[int]]
Index = Union[AxisIndex, Tuple[AxisIndex, ...]]
+ArrayTXd = TypeVar("ArrayTXd", bound=ArrayXd)
-def array_getitem(index: Index) -> Model[ArrayXd, ArrayXd]:
+def array_getitem(index: Index) -> Model[ArrayTXd, ArrayTXd]:
"""Index into input arrays, and return the subarrays.
index:
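# Illustrative sketch: with the bound TypeVar, the output array type mirrors the
# input array type, e.g. a row slice of a Floats2d stays a Floats2d.
from thinc.layers.array_getitem import array_getitem
from thinc.model import Model
from thinc.types import Floats2d

first_rows: Model[Floats2d, Floats2d] = array_getitem(slice(0, 10))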
diff --git a/thinc/layers/cauchysimilarity.py b/thinc/layers/cauchysimilarity.py
index 89c70078b..25af8d9df 100644
--- a/thinc/layers/cauchysimilarity.py
+++ b/thinc/layers/cauchysimilarity.py
@@ -23,14 +23,15 @@ def CauchySimilarity(nI: Optional[int] = None) -> Model[InT, OutT]:
params={"W": None},
)
+
def forward(
model: Model[InT, OutT], X1_X2: InT, is_train: bool
) -> Tuple[OutT, Callable]:
X1, X2 = X1_X2
W = cast(Floats2d, model.get_param("W"))
diff = X1 - X2
- square_diff = diff ** 2
- total = (W * square_diff).sum(axis=1) # type: ignore
+ square_diff = diff**2
+ total = (W * square_diff).sum(axis=1)
sim, bp_sim = inverse(total)
def backprop(d_sim: OutT) -> InT:
diff --git a/thinc/layers/chain.py b/thinc/layers/chain.py
index 324319b61..258ee0902 100644
--- a/thinc/layers/chain.py
+++ b/thinc/layers/chain.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Callable, Optional, TypeVar, Any, Dict
+from typing import Tuple, Callable, Optional, TypeVar, Any, Dict, List, cast
from ..model import Model
from ..config import registry
@@ -7,9 +7,8 @@
InT = TypeVar("InT")
-OutT = TypeVar("OutT")
MidT = TypeVar("MidT")
-
+OutT = TypeVar("OutT")
# Keep this function so we can provide variable arguments via the config
@registry.layers("chain.v1")
@@ -18,29 +17,31 @@ def chain_no_types(*layer: Model) -> Model:
def chain(
- layer1: Model[InT, MidT], layer2: Model[MidT, OutT], *layers: Model
+ layer1: Model[InT, MidT], layer2: Model[MidT, Any], *layers: Model[Any, Any]
) -> Model[InT, XY_YZ_OutT]:
"""Compose two models `f` and `g` such that they become layers of a single
feed-forward model that computes `g(f(x))`.
Also supports chaining more than 2 layers.
+ Note that the type checking for additional layers is carried out by the Thinc Mypy plugin.
"""
- layers = (layer1, layer2) + layers
+ all_layers: List[Model[Any, Any]] = [layer1, layer2]
+ all_layers.extend(layers)
dims: Dict[str, Optional[int]] = {"nO": None}
# set input dimension only if first layer has one - should be "False" otherwise
- if layers[0].has_dim("nI") is True:
- dims["nI"] = layers[0].get_dim("nI")
- if layers[0].has_dim("nI") is None:
+ if all_layers[0].has_dim("nI") is True:
+ dims["nI"] = all_layers[0].get_dim("nI")
+ if all_layers[0].has_dim("nI") is None:
dims["nI"] = None
# set output dimension according to last layer
- if layers[-1].has_dim("nO") is True:
- dims["nO"] = layers[-1].get_dim("nO")
+ if all_layers[-1].has_dim("nO") is True:
+ dims["nO"] = all_layers[-1].get_dim("nO")
- model: Model[InT, Any] = Model(
- ">>".join(layer.name for layer in layers),
+ model: Model[InT, XY_YZ_OutT] = Model(
+ ">>".join(layer.name for layer in all_layers),
forward,
init=init,
dims=dims,
- layers=layers,
+ layers=all_layers,
)
return model
@@ -65,7 +66,9 @@ def backprop(dY: OutT) -> InT:
def init(
- model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
+ model: Model[InT, OutT],
+ X: Optional[InT] = None,
+ Y: Optional[OutT] = None,
) -> None:
if X is None and Y is None:
for layer in model.layers:
@@ -92,10 +95,9 @@ def init(
model.set_dim("nI", model.layers[0].get_dim("nI"))
if model.has_dim("nO") is None:
try:
- nO = get_width(curr_input) # type: ignore
+ nO = get_width(curr_input) # type: ignore[arg-type]
+ model.set_dim("nO", nO)
except ValueError:
if model.layers[-1].has_dim("nO"):
nO = model.layers[-1].get_dim("nO")
- else:
- nO = None # type: ignore
- model.set_dim("nO", nO)
+ model.set_dim("nO", nO)
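# Illustrative sketch: chain() now types only the first two layers strictly and
# leaves any additional layers to the Thinc mypy plugin (the layers below are
# just example choices).
from thinc.api import chain, Relu, Linear

model = chain(Relu(nO=8), Relu(nO=8), Linear(nO=2))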
diff --git a/thinc/layers/concatenate.py b/thinc/layers/concatenate.py
index c9faefd63..78e4c558b 100644
--- a/thinc/layers/concatenate.py
+++ b/thinc/layers/concatenate.py
@@ -1,5 +1,5 @@
-from typing import Any, List, Tuple, Callable, Optional, TypeVar, cast, Dict, Union
-
+from typing import Any, List, Tuple, Callable, Optional
+from typing import TypeVar, cast, Dict, Union, Sequence
from ..model import Model
from ..config import registry
from ..types import Array2d, Ragged
@@ -9,7 +9,7 @@
InT = TypeVar("InT", bound=Any)
-OutT = TypeVar("OutT", bound=Union[Array2d, List[Array2d], Ragged])
+OutT = TypeVar("OutT", bound=Union[Array2d, Sequence[Array2d], Ragged])
@registry.layers("concatenate.v1")
@@ -43,15 +43,18 @@ def concatenate(*layers: Model) -> Model[InT, XY_XY_OutT]:
def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]:
Ys, callbacks = zip(*[layer(X, is_train=is_train) for layer in model.layers])
if isinstance(Ys[0], list):
- return _list_forward(model, X, Ys, callbacks, is_train) # type: ignore
+ data_l, backprop = _list_forward(model, X, Ys, callbacks, is_train)
+ return cast(OutT, data_l), backprop
elif isinstance(Ys[0], Ragged):
- return _ragged_forward(model, X, Ys, callbacks, is_train) # type: ignore
+ data_r, backprop = _ragged_forward(model, X, Ys, callbacks, is_train)
+ return cast(OutT, data_r), backprop
else:
- return _array_forward(model, X, Ys, callbacks, is_train) # type: ignore
+ data_a, backprop = _array_forward(model, X, Ys, callbacks, is_train)
+ return cast(OutT, data_a), backprop
def _array_forward(
- model: Model[InT, Array2d], X, Ys, callbacks, is_train: bool
+ model: Model[InT, OutT], X, Ys: List, callbacks, is_train: bool
) -> Tuple[Array2d, Callable]:
widths = [Y.shape[1] for Y in Ys]
output = model.ops.xp.hstack(Ys)
@@ -61,7 +64,9 @@ def backprop(d_output: Array2d) -> InT:
dX = callbacks[0](dY)
start = widths[0]
add_gradients = hasattr(dX, "__add__") or hasattr(dX, "__iadd__")
- add_gradients_data = hasattr(dX, "data") and (hasattr(dX.data, "__add__") or hasattr(dX.data, "__iadd__"))
+ add_gradients_data = hasattr(dX, "data") and (
+ hasattr(dX.data, "__add__") or hasattr(dX.data, "__iadd__")
+ )
for bwd, width in zip(callbacks[1:], widths[1:]):
dY = model.ops.as_contig(d_output[:, start : start + width])
gradient = bwd(dY)
@@ -76,7 +81,7 @@ def backprop(d_output: Array2d) -> InT:
def _ragged_forward(
- model: Model[InT, Ragged], X, Ys, callbacks, is_train: bool
+ model: Model[InT, OutT], X, Ys: List, callbacks, is_train: bool
) -> Tuple[Ragged, Callable]:
widths = [Y.dataXd.shape[1] for Y in Ys]
@@ -98,28 +103,28 @@ def backprop(d_output: Ragged) -> InT:
return output, backprop
-def _list_forward(model: Model[InT, List[Array2d]], X, Ys, callbacks, is_train: bool):
- lengths = model.ops.asarray1i([len(x) for x in X])
- Ys = [model.ops.xp.concatenate(Y, axis=0) for Y in Ys]
- widths = [Y.shape[1] for Y in Ys]
- out_array = model.ops.xp.hstack(Ys)
- output = model.ops.unflatten(out_array, lengths)
-
- def backprop(d_output: List[Array2d]) -> InT:
+def _list_forward(
+ model: Model[InT, OutT], X, Ys: List, callbacks, is_train: bool
+) -> Tuple[Sequence[Array2d], Callable]:
+ def backprop(d_output: Sequence[Array2d]) -> InT:
d_out_array = model.ops.xp.concatenate(d_output, axis=0)
dY = model.ops.as_contig(d_out_array[:, : widths[0]])
# We want to generalize unflatten later.
- dY = model.ops.unflatten(dY, lengths) # type: ignore
+ dY = model.ops.unflatten(dY, lengths)
dX = callbacks[0](dY)
start = widths[0]
for bwd, width in zip(callbacks[1:], widths[1:]):
dY = model.ops.as_contig(d_out_array[:, start : start + width])
- dY = model.ops.unflatten(dY, lengths) # type: ignore
+ dY = model.ops.unflatten(dY, lengths)
dX += bwd(dY)
start += width
return dX
- return output, backprop
+ lengths = model.ops.asarray1i([len(x) for x in X])
+ Ys = [model.ops.xp.concatenate(Y, axis=0) for Y in Ys]
+ widths = [Y.shape[1] for Y in Ys]
+ out_array = model.ops.xp.hstack(Ys)
+ return model.ops.unflatten(out_array, lengths), backprop
def init(
diff --git a/thinc/layers/dropout.py b/thinc/layers/dropout.py
index fedc3310b..f4fa29445 100644
--- a/thinc/layers/dropout.py
+++ b/thinc/layers/dropout.py
@@ -1,12 +1,11 @@
-from typing import Tuple, Callable, List, TypeVar, Any
+from typing import Tuple, Callable, List, TypeVar, cast, Union, Sequence
from ..model import Model
from ..config import registry
from ..types import ArrayXd, Ragged, Padded
-InT = TypeVar("InT")
-ArrayT = TypeVar("ArrayT", bound=ArrayXd)
+InT = TypeVar("InT", bound=Union[ArrayXd, Sequence[ArrayXd], Ragged, Padded])
@registry.layers("Dropout.v1")
@@ -18,40 +17,39 @@ def Dropout(rate: float = 0.0) -> Model[InT, InT]:
return Model("dropout", forward, attrs={"dropout_rate": rate, "is_enabled": True})
-# We're getting type hell here, I think because of the instance checks?
-# It's sort of painful, because I think this confused the types of other
-# layers that are trying to use dropout.
-# I've relaxed the types for now, but it'd be good to understand what's wrong
-# here.
-def forward(model: Model, X, is_train: bool) -> Tuple[Any, Callable]:
+def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]:
rate = model.attrs["dropout_rate"]
is_enabled = model.attrs["is_enabled"] and is_train
if rate == 0 or not is_enabled:
return X, lambda dY: dY
elif isinstance(X, Ragged):
- return _dropout_ragged(model, X, is_train)
+ data_r, backprop = _dropout_ragged(model, X, is_train)
+ return cast(InT, data_r), backprop
elif isinstance(X, Padded):
- return _dropout_padded(model, X, is_train)
- elif isinstance(X, list):
- return _dropout_lists(model, X, is_train)
+ data_p, backprop = _dropout_padded(model, X, is_train)
+ return cast(InT, data_p), backprop
+ elif isinstance(X, Sequence):
+ data_l, backprop = _dropout_lists(model, X, is_train)
+ return cast(InT, data_l), backprop
else:
- return _dropout_array(model, X, is_train)
+ data_a, backprop = _dropout_array(model, cast(ArrayXd, X), is_train)
+ return cast(InT, data_a), backprop
def _dropout_array(
- model: Model[ArrayT, ArrayT], X: ArrayT, is_train: bool
-) -> Tuple[ArrayT, Callable]:
+ model: Model[InT, InT], X: ArrayXd, is_train: bool
+) -> Tuple[ArrayXd, Callable]:
rate = model.attrs["dropout_rate"]
mask = model.ops.get_dropout_mask(X.shape, rate)
- def backprop(dY: ArrayT) -> ArrayT:
+ def backprop(dY: ArrayXd) -> ArrayXd:
return dY * mask
- return X * mask, backprop
+ return cast(ArrayXd, X * mask), backprop
def _dropout_padded(
- model: Model, Xp: Padded, is_train: bool
+ model: Model[InT, InT], Xp: Padded, is_train: bool
) -> Tuple[Padded, Callable]:
X = Xp.data
mask = model.ops.get_dropout_mask(X.shape, model.attrs["dropout_rate"])
@@ -64,7 +62,7 @@ def backprop(dYp: Padded) -> Padded:
def _dropout_ragged(
- model: Model, Xr: Ragged, is_train: bool
+ model: Model[InT, InT], Xr: Ragged, is_train: bool
) -> Tuple[Ragged, Callable]:
X = Xr.data
lengths = Xr.lengths
@@ -78,13 +76,13 @@ def backprop(dYr: Ragged) -> Ragged:
def _dropout_lists(
- model: Model[ArrayT, ArrayT], Xs: List[ArrayT], is_train: bool
-) -> Tuple[List[ArrayT], Callable]:
+ model: Model[InT, InT], Xs: Sequence[ArrayXd], is_train: bool
+) -> Tuple[Sequence[ArrayXd], Callable]:
rate = model.attrs["dropout_rate"]
masks = [model.ops.get_dropout_mask(X.shape, rate) for X in Xs]
Ys = [X * mask for X, mask in zip(Xs, masks)]
- def backprop(dYs: List[ArrayT]) -> List[ArrayT]:
+ def backprop(dYs: List[ArrayXd]) -> List[ArrayXd]:
return [dY * mask for dY, mask in zip(dYs, masks)]
return Ys, backprop
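# Usage sketch: with InT bound to the supported input kinds, the same Dropout
# instance keeps its input type, whether that is an array, a Ragged, a Padded
# or a sequence of arrays (assumes Dropout and NumpyOps from thinc.api).
from thinc.api import Dropout, NumpyOps

ops = NumpyOps()
drop = Dropout(0.2)
Y, backprop = drop(ops.alloc2f(4, 8), is_train=True)  # Floats2d in, Floats2d out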
diff --git a/thinc/layers/embed.py b/thinc/layers/embed.py
index 80b25266d..9e3587460 100644
--- a/thinc/layers/embed.py
+++ b/thinc/layers/embed.py
@@ -1,4 +1,4 @@
-from typing import Dict, Callable, Tuple, Optional, Union, cast
+from typing import Dict, Callable, Tuple, Optional, Union, cast, TypeVar
from .chain import chain
from .array_getitem import ints_getitem
@@ -9,7 +9,7 @@
from ..util import get_width, partial
-InT = Union[Ints1d, Ints2d]
+InT = TypeVar("InT", bound=Union[Ints1d, Ints2d])
OutT = Floats2d
@@ -26,7 +26,7 @@ def Embed(
attrs: Dict[str, Union[None, int, float]] = {}
if dropout is not None:
attrs["dropout_rate"] = dropout
- model = Model( # type: ignore
+ model: Model = Model(
"embed",
forward,
init=partial(init, initializer),
@@ -45,7 +45,7 @@ def Embed(
def forward(
- model: Model[InT, OutT], ids: Ints1d, is_train: bool
+ model: Model[Ints1d, OutT], ids: Ints1d, is_train: bool
) -> Tuple[OutT, Callable]:
vectors = cast(Floats2d, model.get_param("E"))
nO = vectors.shape[1]
@@ -72,7 +72,7 @@ def backprop(d_output: OutT) -> Ints1d:
def init(
initializer: Callable,
- model: Model[InT, OutT],
+ model: Model[Ints1d, OutT],
X: Optional[Ints1d] = None,
Y: Optional[OutT] = None,
) -> None:
diff --git a/thinc/layers/expand_window.py b/thinc/layers/expand_window.py
index 15b4534f5..1075a49a2 100644
--- a/thinc/layers/expand_window.py
+++ b/thinc/layers/expand_window.py
@@ -24,7 +24,7 @@ def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callab
def _expand_window_floats(
- model: Model[Floats2d, Floats2d], X: Floats2d
+ model: Model[InT, InT], X: Floats2d
) -> Tuple[Floats2d, Callable]:
nW = model.attrs["window_size"]
if len(X) > 0:
@@ -40,7 +40,7 @@ def backprop(dY: Floats2d) -> Floats2d:
def _expand_window_ragged(
- model: Model[Ragged, Ragged], Xr: Ragged
+ model: Model[InT, InT], Xr: Ragged
) -> Tuple[Ragged, Callable]:
nW = model.attrs["window_size"]
Y = Ragged(
diff --git a/thinc/layers/hashembed.py b/thinc/layers/hashembed.py
index 1830aff93..74b85c7cf 100644
--- a/thinc/layers/hashembed.py
+++ b/thinc/layers/hashembed.py
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, Tuple, Optional, Any, Union, cast
+from typing import Callable, Dict, Tuple, Optional, Any, Union, cast, TypeVar
from .chain import chain
from .array_getitem import ints_getitem
@@ -9,7 +9,7 @@
from ..util import partial
-InT = Union[Ints2d, Ints1d]
+InT = TypeVar("InT", bound=Union[Ints1d, Ints2d])
OutT = Floats2d
@@ -35,7 +35,7 @@ def HashEmbed(
attrs: Dict[str, Any] = {"column": column, "seed": seed}
if dropout is not None:
attrs["dropout_rate"] = dropout
- model = Model( # type: ignore
+ model: Model = Model(
"hashembed",
forward,
init=partial(init, initializer),
@@ -56,7 +56,7 @@ def HashEmbed(
def forward(
- model: Model[InT, OutT], ids: Ints1d, is_train: bool
+ model: Model[Ints1d, OutT], ids: Ints1d, is_train: bool
) -> Tuple[OutT, Callable]:
vectors = cast(Floats2d, model.get_param("E"))
nV = vectors.shape[0]
@@ -64,7 +64,7 @@ def forward(
if len(ids) == 0:
output: Floats2d = model.ops.alloc((0, nO), dtype=vectors.dtype)
else:
- ids = model.ops.as_contig(ids, dtype="uint64") # type: ignore
+ ids = model.ops.as_contig(ids, dtype="uint64")
nN = ids.shape[0]
seed: int = model.attrs["seed"]
keys = model.ops.hash(ids, seed) % nV
@@ -92,7 +92,7 @@ def backprop(d_vectors: OutT) -> Ints1d:
def init(
initializer: Callable,
- model: Model[InT, OutT],
+ model: Model[Ints1d, OutT],
X: Optional[Ints1d] = None,
Y: Optional[OutT] = None,
) -> None:
diff --git a/thinc/layers/list2array.py b/thinc/layers/list2array.py
index 2b4b03141..fff5befc0 100644
--- a/thinc/layers/list2array.py
+++ b/thinc/layers/list2array.py
@@ -1,12 +1,12 @@
-from typing import Tuple, Callable
+from typing import Tuple, Callable, TypeVar, List, Union, cast
from ..model import Model
from ..config import registry
-from ..types import Array2d, List2d
+from ..types import Array2d
-InT = List2d
-OutT = Array2d
+OutT = TypeVar("OutT", bound=Array2d)
+InT = List[OutT]
@registry.layers("list2array.v1")
@@ -22,6 +22,6 @@ def forward(model: Model[InT, OutT], Xs: InT, is_train: bool) -> Tuple[OutT, Cal
lengths = model.ops.asarray1i([len(x) for x in Xs])
def backprop(dY: OutT) -> InT:
- return model.ops.unflatten(dY, lengths) # type: ignore
+ return model.ops.unflatten(dY, lengths)
- return model.ops.flatten(Xs), backprop # type: ignore
+ return model.ops.flatten(Xs), backprop
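# Illustrative sketch: with OutT bound to Array2d, list2array preserves the
# concrete 2d array type, e.g. List[Floats2d] -> Floats2d (assumes list2array
# and NumpyOps from thinc.api).
from thinc.api import NumpyOps, list2array

ops = NumpyOps()
model = list2array()
Y, backprop = model([ops.alloc2f(2, 4), ops.alloc2f(3, 4)], is_train=False)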
diff --git a/thinc/layers/list2padded.py b/thinc/layers/list2padded.py
index 28aabace7..2a02f90e0 100644
--- a/thinc/layers/list2padded.py
+++ b/thinc/layers/list2padded.py
@@ -1,11 +1,11 @@
-from typing import Tuple, Callable
+from typing import Tuple, Callable, TypeVar, cast
from ..types import Padded, List2d
from ..model import Model
from ..config import registry
-InT = List2d
+InT = TypeVar("InT", bound=List2d)
OutT = Padded
@@ -16,9 +16,9 @@ def list2padded() -> Model[InT, OutT]:
def forward(model: Model[InT, OutT], Xs: InT, is_train: bool) -> Tuple[OutT, Callable]:
- Yp = model.ops.list2padded(Xs) # type: ignore
+ Yp = model.ops.list2padded(Xs)
def backprop(dYp: OutT) -> InT:
- return model.ops.padded2list(dYp) # type: ignore
+ return cast(InT, model.ops.padded2list(dYp))
return Yp, backprop
diff --git a/thinc/layers/list2ragged.py b/thinc/layers/list2ragged.py
index 7b293fa3d..a63237dfe 100644
--- a/thinc/layers/list2ragged.py
+++ b/thinc/layers/list2ragged.py
@@ -1,11 +1,11 @@
-from typing import Tuple, List, Callable
+from typing import Tuple, List, Callable, cast, TypeVar
from ..model import Model
from ..config import registry
-from ..types import ArrayXd, Ragged
+from ..types import ListXd, ArrayXd, Ragged
-InT = List[ArrayXd]
+InT = TypeVar("InT", bound=ListXd)
OutT = Ragged
@@ -20,7 +20,7 @@ def list2ragged() -> Model[InT, OutT]:
def forward(model: Model[InT, OutT], Xs: InT, is_train: bool) -> Tuple[OutT, Callable]:
def backprop(dYr: OutT) -> InT:
- return model.ops.unflatten(dYr.data, dYr.lengths) # type: ignore
+ return cast(InT, model.ops.unflatten(dYr.data, dYr.lengths))
lengths = model.ops.asarray1i([len(x) for x in Xs])
return Ragged(model.ops.flatten(Xs), lengths), backprop
diff --git a/thinc/layers/lstm.py b/thinc/layers/lstm.py
index 626d2f567..2acc70ac3 100644
--- a/thinc/layers/lstm.py
+++ b/thinc/layers/lstm.py
@@ -45,13 +45,13 @@ def PyTorchLSTM(
from .pytorchwrapper import PyTorchRNNWrapper
if depth == 0:
- return noop() # type: ignore
+ return noop() # type: ignore[misc]
nH = nO
if bi:
nH = nO // 2
pytorch_rnn = PyTorchRNNWrapper(
- torch.nn.LSTM(nI, nH, depth, bidirectional=bi, dropout=dropout)
- )
+ torch.nn.LSTM(nI, nH, depth, bidirectional=bi, dropout=dropout)
+ )
pytorch_rnn.set_dim("nO", nO)
pytorch_rnn.set_dim("nI", nI)
return with_padded(pytorch_rnn)
@@ -161,7 +161,7 @@ def _padded_to_packed(ops: Ops, Xp: Padded) -> Ragged:
start = 0
for t in range(Xp.size_at_t.shape[0]):
batch_size = Xp.size_at_t[t]
- Y[start : start + batch_size] = Xp.data[t, :batch_size]
+ Y[start : start + batch_size] = Xp.data[t, :batch_size] # type: ignore[assignment]
start += batch_size
return Ragged(Y, Xp.size_at_t)
diff --git a/thinc/layers/maxout.py b/thinc/layers/maxout.py
index 6104e0c0e..4f361f78d 100644
--- a/thinc/layers/maxout.py
+++ b/thinc/layers/maxout.py
@@ -54,7 +54,7 @@ def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Call
def backprop(d_best: OutT) -> InT:
dZ = model.ops.backprop_maxout(d_best, which, nP)
# TODO: Add sum methods for Floats3d
- model.inc_grad("b", dZ.sum(axis=0)) # type: ignore
+ model.inc_grad("b", dZ.sum(axis=0)) # type: ignore[call-overload]
dY = model.ops.reshape2f(dZ, dZ.shape[0], nO * nP)
dW = model.ops.reshape3f(model.ops.gemm(dY, X, trans1=True), nO, nP, nI)
model.inc_grad("W", dW)
diff --git a/thinc/layers/padded2list.py b/thinc/layers/padded2list.py
index 42a6e4c26..8f1bee7e8 100644
--- a/thinc/layers/padded2list.py
+++ b/thinc/layers/padded2list.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Callable
+from typing import Tuple, Callable, TypeVar, cast
from ..types import Padded, List2d
from ..model import Model
@@ -6,7 +6,7 @@
InT = Padded
-OutT = List2d
+OutT = TypeVar("OutT", bound=List2d)
@registry.layers("padded2list.v1")
@@ -15,11 +15,13 @@ def padded2list() -> Model[InT, OutT]:
return Model(f"padded2list", forward)
-def forward(model: Model[InT, OutT], Xp: InT, is_train: bool) -> Tuple[OutT, Callable]:
- Ys = model.ops.padded2list(Xp) # type: ignore
+def forward(
+ model: Model[InT, OutT], Xp: InT, is_train: bool
+) -> Tuple[OutT, Callable[[OutT], InT]]:
+ Ys = cast(OutT, model.ops.padded2list(Xp))
def backprop(dYs: OutT) -> InT:
- dYp = model.ops.list2padded(dYs) # type: ignore
+ dYp = model.ops.list2padded(dYs)
assert isinstance(dYp, Padded)
return dYp
diff --git a/thinc/layers/ragged2list.py b/thinc/layers/ragged2list.py
index 2c5f321c5..35af28f2f 100644
--- a/thinc/layers/ragged2list.py
+++ b/thinc/layers/ragged2list.py
@@ -1,12 +1,12 @@
-from typing import Tuple, Callable
+from typing import Tuple, Callable, TypeVar, cast
from ..model import Model
from ..config import registry
-from ..types import Ragged, List2d
+from ..types import Ragged, ListXd
InT = Ragged
-OutT = List2d
+OutT = TypeVar("OutT", bound=ListXd)
@registry.layers("ragged2list.v1")
@@ -19,7 +19,8 @@ def forward(model: Model[InT, OutT], Xr: InT, is_train: bool) -> Tuple[OutT, Cal
lengths = Xr.lengths
def backprop(dXs: OutT) -> InT:
- return Ragged(model.ops.flatten(dXs, pad=0), lengths) # type: ignore
+ return Ragged(model.ops.flatten(dXs, pad=0), lengths) # type:ignore[arg-type]
+ # type ignore necessary for older versions of Mypy/Pydantic
- data = model.ops.unflatten(Xr.dataXd, Xr.lengths) # type: ignore
+ data = cast(OutT, model.ops.unflatten(Xr.dataXd, Xr.lengths))
return data, backprop
diff --git a/thinc/layers/reduce_first.py b/thinc/layers/reduce_first.py
index bfc23f613..df7541315 100644
--- a/thinc/layers/reduce_first.py
+++ b/thinc/layers/reduce_first.py
@@ -7,17 +7,20 @@
OutT = TypeVar("OutT", bound=ArrayXd)
+
@registry.layers("reduce_first.v1")
def reduce_first() -> Model[Ragged, OutT]:
"""Reduce ragged-formatted sequences to their first element."""
return Model("reduce_first", forward)
-def forward(model: Model[Ragged, OutT], Xr: Ragged, is_train: bool) -> Tuple[OutT, Callable]:
+def forward(
+ model: Model[Ragged, OutT], Xr: Ragged, is_train: bool
+) -> Tuple[OutT, Callable[[OutT], Ragged]]:
starts = model.ops.alloc1i(Xr.lengths.shape[0])
starts[1:] += Xr.lengths.cumsum()[:-1]
- X = cast(OutT, Xr.dataXd)
- Y = cast(OutT, X[starts]) # type: ignore
+ X = Xr.dataXd
+ Y = cast(OutT, X[starts])
x_shape = Xr.dataXd.shape
lengths = Xr.lengths
@@ -25,8 +28,8 @@ def forward(model: Model[Ragged, OutT], Xr: Ragged, is_train: bool) -> Tuple[Out
def backprop(dY: OutT) -> Ragged:
array_info.check_consistency(dY)
- dX = cast(OutT, model.ops.alloc(x_shape, dtype=dY.dtype))
- dX[starts] = dY # type: ignore
+ dX: OutT = model.ops.alloc(x_shape, dtype=dY.dtype)
+ dX[starts] = dY # type: ignore[assignment]
return Ragged(dX, lengths)
return Y, backprop
diff --git a/thinc/layers/reduce_last.py b/thinc/layers/reduce_last.py
index 51bda3690..e45a65d12 100644
--- a/thinc/layers/reduce_last.py
+++ b/thinc/layers/reduce_last.py
@@ -7,23 +7,26 @@
OutT = TypeVar("OutT", bound=ArrayXd)
+
@registry.layers("reduce_last.v1")
def reduce_last() -> Model[Ragged, OutT]:
"""Reduce ragged-formatted sequences to their last element."""
return Model("reduce_last", forward)
-def forward(model: Model[Ragged, OutT], Xr: Ragged, is_train: bool) -> Tuple[OutT, Callable]:
+def forward(
+ model: Model[Ragged, OutT], Xr: Ragged, is_train: bool
+) -> Tuple[OutT, Callable[[OutT], Ragged]]:
ends = Xr.lengths.cumsum() - 1
- Y = cast(OutT, Xr.dataXd[ends]) # type: ignore
+ Y = cast(OutT, Xr.dataXd[ends])
x_shape = Xr.dataXd.shape
lengths = Xr.lengths
array_info = ArrayInfo.from_array(Y)
def backprop(dY: OutT) -> Ragged:
array_info.check_consistency(dY)
- dX = cast(OutT, model.ops.alloc(x_shape, dtype=dY.dtype))
- dX[ends] = dY # type: ignore
+ dX: OutT = model.ops.alloc(x_shape, dtype=dY.dtype)
+ dX[ends] = dY # type: ignore[assignment]
return Ragged(dX, lengths)
return Y, backprop
diff --git a/thinc/layers/residual.py b/thinc/layers/residual.py
index f4062cad1..3793ee1d5 100644
--- a/thinc/layers/residual.py
+++ b/thinc/layers/residual.py
@@ -4,9 +4,10 @@
from ..config import registry
from ..types import Floats1d, Floats2d, Floats3d, Floats4d, FloatsXd, Ragged, Padded
-
# fmt: off
-InT = TypeVar("InT", List[Floats1d], List[Floats2d], List[Floats3d], List[Floats4d], Ragged, Padded, FloatsXd)
+InT = TypeVar(
+ "InT", List[Floats1d], List[Floats2d], List[Floats3d], List[Floats4d],
+ Ragged, Padded, FloatsXd, Floats1d, Floats2d, Floats3d, Floats4d)
# fmt: on
diff --git a/thinc/layers/resizable.py b/thinc/layers/resizable.py
index 3459296ce..3454684d0 100644
--- a/thinc/layers/resizable.py
+++ b/thinc/layers/resizable.py
@@ -10,8 +10,7 @@
@registry.layers("resizable.v1")
def resizable(layer, resize_layer: Callable) -> Model[InT, OutT]:
- """Container that holds one layer that can change dimensions.
- """
+ """Container that holds one layer that can change dimensions."""
return Model(
f"resizable({layer.name})",
forward,
diff --git a/thinc/layers/siamese.py b/thinc/layers/siamese.py
index cc3c2ca71..82bafacbb 100644
--- a/thinc/layers/siamese.py
+++ b/thinc/layers/siamese.py
@@ -9,7 +9,7 @@
LayerT = TypeVar("LayerT")
SimT = TypeVar("SimT")
InT = Tuple[LayerT, LayerT]
-OutT = ArrayXd
+OutT = TypeVar("OutT", bound=ArrayXd)
@registry.layers("siamese.v1")
diff --git a/thinc/layers/sigmoid_activation.py b/thinc/layers/sigmoid_activation.py
index 02923b6eb..8b3982aea 100644
--- a/thinc/layers/sigmoid_activation.py
+++ b/thinc/layers/sigmoid_activation.py
@@ -17,6 +17,8 @@ def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callab
Y = model.ops.sigmoid(X, inplace=False)
def backprop(dY: InT) -> InT:
- return dY * model.ops.dsigmoid(Y, inplace=False) # type: ignore
+ return cast(
+ InT, dY * model.ops.dsigmoid(Y, inplace=False) # type:ignore[operator]
+ )
return Y, backprop
diff --git a/thinc/layers/tuplify.py b/thinc/layers/tuplify.py
index b95082ddf..99b4d7589 100644
--- a/thinc/layers/tuplify.py
+++ b/thinc/layers/tuplify.py
@@ -1,15 +1,16 @@
-from typing import Callable, Optional, Tuple, Any, TypeVar
+from typing import Optional, Tuple, Any, TypeVar
from ..model import Model
from ..config import registry
InT = TypeVar("InT")
OutT = Tuple
-MidT = TypeVar("MidT")
@registry.layers("tuplify.v1")
-def tuplify(layer1: Model[InT, Any], layer2: Model[InT, Any], *layers) -> Model[InT, Tuple]:
+def tuplify(
+ layer1: Model[InT, Any], layer2: Model[InT, Any], *layers
+) -> Model[InT, Tuple]:
"""Send a separate copy of the input to each child layer, and join the
outputs of the children into a tuple on the way out.
diff --git a/thinc/layers/with_array.py b/thinc/layers/with_array.py
index 3d6d91833..3701fc8a3 100644
--- a/thinc/layers/with_array.py
+++ b/thinc/layers/with_array.py
@@ -2,20 +2,19 @@
from ..model import Model
from ..config import registry
-from ..types import Array2d, Floats2d, Padded, Ragged, ArrayXd, Floats3d
-from ..types import List2d
+from ..types import Padded, Ragged, ArrayXd, Array3d, ListXd
-
-SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, List2d, ArrayXd])
+ArrayTXd = TypeVar("ArrayTXd", bound=ArrayXd)
+SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, ListXd, ArrayXd])
@registry.layers("with_array.v1")
-def with_array(layer: Model[ArrayXd, ArrayXd], pad: int = 0) -> Model[SeqT, SeqT]:
- """Transform sequence data into a contiguous 2d array on the way into and
+def with_array(layer: Model[ArrayTXd, ArrayTXd], pad: int = 0) -> Model[SeqT, SeqT]:
+ """Transform sequence data into a contiguous array on the way into and
out of a model. Handles a variety of sequence types: lists, padded and ragged.
- If the input is a 2d array, it is passed through unchanged.
+ If the input is an array, it is passed through unchanged.
"""
- return Model(
+ model: Model[SeqT, SeqT] = Model(
f"with_array({layer.name})",
forward,
init=init,
@@ -23,21 +22,20 @@ def with_array(layer: Model[ArrayXd, ArrayXd], pad: int = 0) -> Model[SeqT, SeqT
attrs={"pad": pad},
dims={name: layer.maybe_get_dim(name) for name in layer.dim_names},
)
+ return model
-def forward(model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool):
+def forward(
+ model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool
+) -> Tuple[SeqT, Callable]:
if isinstance(Xseq, Ragged):
- return _ragged_forward(
- cast(Model[Ragged, Ragged], model), cast(Ragged, Xseq), is_train
- )
+ return cast(Tuple[SeqT, Callable], _ragged_forward(model, Xseq, is_train))
elif isinstance(Xseq, Padded):
- return _padded_forward(
- cast(Model[Padded, Padded], model), cast(Padded, Xseq), is_train
- )
+ return cast(Tuple[SeqT, Callable], _padded_forward(model, Xseq, is_train))
elif not isinstance(Xseq, (list, tuple)):
return model.layers[0](Xseq, is_train)
else:
- return _list_forward(cast(Model[List2d, List2d], model), Xseq, is_train)
+ return cast(Tuple[SeqT, Callable], _list_forward(model, Xseq, is_train))
def init(
@@ -66,16 +64,16 @@ def _get_array(model, X: SeqT) -> ArrayXd:
def _list_forward(
- model: Model[List2d, List2d], Xs: List2d, is_train: bool
-) -> Tuple[List2d, Callable]:
- layer = model.layers[0]
+ model: Model[SeqT, SeqT], Xs: ListXd, is_train: bool
+) -> Tuple[ListXd, Callable]:
+ layer: Model[ArrayXd, ArrayXd] = model.layers[0]
pad = model.attrs["pad"]
lengths = layer.ops.asarray1i([len(seq) for seq in Xs])
- Xf = layer.ops.flatten(Xs, pad=pad) # type: ignore
+ Xf = layer.ops.flatten(Xs, pad=pad)
Yf, get_dXf = layer(Xf, is_train)
- def backprop(dYs: List2d) -> List2d:
- dYf = layer.ops.flatten(dYs, pad=pad) # type: ignore
+ def backprop(dYs: ListXd) -> ListXd:
+ dYf = layer.ops.flatten(dYs, pad=pad)
dXf = get_dXf(dYf)
return layer.ops.unflatten(dXf, lengths, pad=pad)
@@ -83,7 +81,7 @@ def backprop(dYs: List2d) -> List2d:
def _ragged_forward(
- model: Model[Ragged, Ragged], Xr: Ragged, is_train: bool
+ model: Model[SeqT, SeqT], Xr: Ragged, is_train: bool
) -> Tuple[Ragged, Callable]:
layer: Model[ArrayXd, ArrayXd] = model.layers[0]
Y, get_dX = layer(Xr.dataXd, is_train)
@@ -95,9 +93,9 @@ def backprop(dYr: Ragged) -> Ragged:
def _padded_forward(
- model: Model[Padded, Padded], Xp: Padded, is_train: bool
+ model: Model[SeqT, SeqT], Xp: Padded, is_train: bool
) -> Tuple[Padded, Callable]:
- layer: Model[Floats3d, Floats3d] = model.layers[0]
+ layer: Model[Array3d, Array3d] = model.layers[0]
Y, get_dX = layer(Xp.data, is_train)
def backprop(dYp: Padded) -> Padded:
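# Usage sketch: with ListXd in the SeqT bound, with_array can wrap an inner
# layer over list, Ragged or Padded sequences of any array dimensionality
# (Relu is just an example inner layer; assumes thinc.api exports).
from thinc.api import with_array, Relu

model = with_array(Relu(nO=8))  # accepts e.g. List[Floats2d], Ragged or Padded input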
diff --git a/thinc/layers/with_array2d.py b/thinc/layers/with_array2d.py
index 15f4420d2..9f7de213c 100644
--- a/thinc/layers/with_array2d.py
+++ b/thinc/layers/with_array2d.py
@@ -1,9 +1,8 @@
-from typing import Tuple, Callable, Optional, TypeVar, Union, cast
+from typing import Tuple, Callable, Optional, TypeVar, cast, List, Union
from ..model import Model
from ..config import registry
-from ..types import Array2d, Floats2d, Padded, Ragged, ArrayXd
-from ..types import List2d
+from ..types import Array2d, Floats2d, List2d, Padded, Ragged
ValT = TypeVar("ValT", bound=Array2d)
@@ -26,19 +25,18 @@ def with_array2d(layer: Model[ValT, ValT], pad: int = 0) -> Model[SeqT, SeqT]:
)
-def forward(model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool):
+def forward(
+ model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool
+) -> Tuple[SeqT, Callable]:
if isinstance(Xseq, Ragged):
- return _ragged_forward(
- cast(Model[Ragged, Ragged], model), cast(Ragged, Xseq), is_train
- )
+ return cast(Tuple[SeqT, Callable], _ragged_forward(model, Xseq, is_train))
elif isinstance(Xseq, Padded):
- return _padded_forward(
- cast(Model[Padded, Padded], model), cast(Padded, Xseq), is_train
- )
+ return cast(Tuple[SeqT, Callable], _padded_forward(model, Xseq, is_train))
elif not isinstance(Xseq, (list, tuple)):
return model.layers[0](Xseq, is_train)
else:
- return _list_forward(cast(Model[List2d, List2d], model), Xseq, is_train)
+ return cast(Tuple[SeqT, Callable], _list_forward(model, Xseq, is_train))
def init(
@@ -69,16 +67,16 @@ def _get_array(model, X: SeqT) -> Array2d:
def _list_forward(
- model: Model[List2d, List2d], Xs: List2d, is_train: bool
+ model: Model[SeqT, SeqT], Xs: List2d, is_train: bool
) -> Tuple[List2d, Callable]:
- layer = model.layers[0]
+ layer: Model[Array2d, Array2d] = model.layers[0]
pad = model.attrs["pad"]
lengths = layer.ops.asarray1i([len(seq) for seq in Xs])
- Xf = layer.ops.flatten(Xs, pad=pad) # type: ignore
+ Xf = layer.ops.flatten(Xs, pad=pad)
Yf, get_dXf = layer(Xf, is_train)
def backprop(dYs: List2d) -> List2d:
- dYf = layer.ops.flatten(dYs, pad=pad) # type: ignore
+ dYf = layer.ops.flatten(dYs, pad=pad)
dXf = get_dXf(dYf)
return layer.ops.unflatten(dXf, lengths, pad=pad)
@@ -86,23 +84,23 @@ def backprop(dYs: List2d) -> List2d:
def _ragged_forward(
- model: Model[Ragged, Ragged], Xr: Ragged, is_train: bool
+ model: Model[SeqT, SeqT], Xr: Ragged, is_train: bool
) -> Tuple[Ragged, Callable]:
- layer: Model[Array2d, ArrayXd] = model.layers[0]
+ layer: Model[Array2d, Array2d] = model.layers[0]
Y, get_dX = layer(Xr.data, is_train)
x_shape = Xr.dataXd.shape
def backprop(dYr: Ragged) -> Ragged:
return Ragged(get_dX(dYr.dataXd).reshape(x_shape), dYr.lengths)
-
+
return Ragged(Y, Xr.lengths), backprop
def _padded_forward(
- model: Model[Padded, Padded], Xp: Padded, is_train: bool
+ model: Model[SeqT, SeqT], Xp: Padded, is_train: bool
) -> Tuple[Padded, Callable]:
layer: Model[Array2d, Array2d] = model.layers[0]
- X = model.ops.reshape2f(
+ X = model.ops.reshape2(
Xp.data, Xp.data.shape[0] * Xp.data.shape[1], Xp.data.shape[2]
)
Y2d, get_dX = layer(X, is_train)
@@ -112,7 +110,7 @@ def _padded_forward(
def backprop(dYp: Padded) -> Padded:
assert isinstance(dYp, Padded)
- dY = model.ops.reshape2f(
+ dY = model.ops.reshape2(
dYp.data, dYp.data.shape[0] * dYp.data.shape[1], dYp.data.shape[2]
)
dX2d = get_dX(dY)
diff --git a/thinc/layers/with_flatten.py b/thinc/layers/with_flatten.py
index 94246c5d2..2ae50f282 100644
--- a/thinc/layers/with_flatten.py
+++ b/thinc/layers/with_flatten.py
@@ -1,34 +1,34 @@
-from typing import Tuple, Callable, Sequence, Any, List, TypeVar
+from typing import Tuple, Callable, Sequence, Any, cast, TypeVar, Optional, List
from ..model import Model
from ..config import registry
-from ..types import Array2d, List2d
+from ..types import ListXd
ItemT = TypeVar("ItemT")
InT = Sequence[Sequence[ItemT]]
-OutT = List2d
+OutT = TypeVar("OutT", bound=ListXd)
@registry.layers("with_flatten.v1")
-def with_flatten(layer: Model) -> Model[InT, OutT]:
+def with_flatten(layer: Model[InT, InT]) -> Model[OutT, OutT]:
return Model(f"with_flatten({layer.name})", forward, layers=[layer], init=init)
def forward(
- model: Model[InT, OutT], Xnest: InT, is_train: bool
+ model: Model[OutT, OutT], Xnest: OutT, is_train: bool
) -> Tuple[OutT, Callable]:
- layer: Model[Sequence[Any], Array2d] = model.layers[0]
- Xflat: Sequence[Any] = _flatten(Xnest)
+ layer: Model[InT, InT] = model.layers[0]
+ Xflat: Sequence[Sequence[Any]] = _flatten(Xnest)
Yflat, backprop_layer = layer(Xflat, is_train)
# Get the split points. We want n-1 splits for n items.
arr = layer.ops.asarray1i([len(x) for x in Xnest[:-1]])
splits = arr.cumsum()
Ynest = layer.ops.xp.split(Yflat, splits, axis=0)
- def backprop(dYnest: OutT) -> InT:
- # I think the input/output types might be wrong here?
- dYflat = model.ops.flatten(dYnest) # type: ignore
+ def backprop(dYnest: OutT) -> OutT:
+ dYflat = model.ops.flatten(dYnest) # type: ignore[arg-type, var-annotated]
+ # type ignore necessary for older versions of Mypy/Pydantic
dXflat = backprop_layer(dYflat)
dXnest = layer.ops.xp.split(dXflat, splits, axis=-1)
return dXnest
@@ -36,14 +36,16 @@ def backprop(dYnest: OutT) -> InT:
return Ynest, backprop
-def _flatten(nested: InT) -> List[ItemT]:
- flat: List[ItemT] = []
+def _flatten(nested: OutT) -> InT:
+ flat: List = []
for item in nested:
flat.extend(item)
- return flat
+ return cast(InT, flat)
-def init(model, X=None, Y=None) -> None:
+def init(
+ model: Model[OutT, OutT], X: Optional[OutT] = None, Y: Optional[OutT] = None
+) -> None:
model.layers[0].initialize(
_flatten(X) if X is not None else None,
model.layers[0].ops.xp.hstack(Y) if Y is not None else None,
diff --git a/thinc/layers/with_list.py b/thinc/layers/with_list.py
index 6b1a51fe7..9f86c24dc 100644
--- a/thinc/layers/with_list.py
+++ b/thinc/layers/with_list.py
@@ -1,11 +1,10 @@
from typing import Tuple, Callable, List, Optional, TypeVar, Union, cast
-from ..types import Padded, Ragged, Floats2d, List2d
+from ..types import Padded, Ragged, Array2d, List2d, Floats2d, Ints2d
from ..model import Model
from ..config import registry
-
-SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, List2d])
+SeqT = TypeVar("SeqT", Padded, Ragged, List2d, List[Floats2d], List[Ints2d])
@registry.layers("with_list.v1")
@@ -23,14 +22,12 @@ def forward(
model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool
) -> Tuple[SeqT, Callable]:
layer: Model[List2d, List2d] = model.layers[0]
- Y: Union[Padded, Ragged, List2d]
if isinstance(Xseq, Padded):
- Y, backprop = _padded_forward(layer, cast(Padded, Xseq), is_train)
+ return _padded_forward(layer, Xseq, is_train)
elif isinstance(Xseq, Ragged):
- Y, backprop = _ragged_forward(layer, cast(Ragged, Xseq), is_train)
+ return _ragged_forward(layer, Xseq, is_train)
else:
- Y, backprop = layer(cast(List2d, Xseq), is_train)
- return cast(Tuple[SeqT, Callable], (Y, backprop))
+ return cast(Tuple[SeqT, Callable], layer(cast(List2d, Xseq), is_train))
def init(
@@ -51,7 +48,9 @@ def _get_list(model, seq):
return seq
-def _ragged_forward(layer, Xr, is_train):
+def _ragged_forward(
+ layer: Model[List2d, List2d], Xr: Ragged, is_train: bool
+) -> Tuple[Ragged, Callable]:
# Assign these to locals, to keep code a bit shorter.
unflatten = layer.ops.unflatten
flatten = layer.ops.flatten
@@ -62,12 +61,17 @@ def _ragged_forward(layer, Xr, is_train):
Ys, get_dXs = layer(unflatten(Xr.data, Xr.lengths), is_train)
def backprop(dYr: Ragged):
- return Ragged(flatten(get_dXs(unflatten(dYr.data, dYr.lengths))), dYr.lengths)
+ return Ragged(
+ flatten(get_dXs(unflatten(dYr.data, dYr.lengths))),
+ dYr.lengths,
+ )
return Ragged(flatten(Ys), Xr.lengths), backprop
-def _padded_forward(layer, Xp, is_train):
+def _padded_forward(
+ layer: Model[List2d, List2d], Xp: Padded, is_train: bool
+) -> Tuple[Padded, Callable]:
# Assign these to locals, to keep code a bit shorter.
padded2list = layer.ops.padded2list
list2padded = layer.ops.list2padded
@@ -80,4 +84,4 @@ def _padded_forward(layer, Xp, is_train):
def backprop(dYp):
return list2padded(get_dXs(padded2list(dYp)))
- return list2padded(cast(List[Floats2d], Ys)), backprop
+ return list2padded(Ys), backprop
diff --git a/thinc/layers/with_padded.py b/thinc/layers/with_padded.py
index e2e33d5f3..379df1bef 100644
--- a/thinc/layers/with_padded.py
+++ b/thinc/layers/with_padded.py
@@ -1,6 +1,6 @@
-from typing import Tuple, Callable, Optional, TypeVar, Union, cast
+from typing import Tuple, Callable, Optional, TypeVar, Union, cast, List
-from ..types import Padded, Ragged, Array2d, Floats3d, Ints1d, Floats2d, List2d
+from ..types import Padded, Ragged, Floats3d, Ints1d, List2d, Array2d
from ..model import Model
from ..config import registry
from ..util import is_xp_array
@@ -25,18 +25,23 @@ def forward(
model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool
) -> Tuple[SeqT, Callable]:
layer: Model[Padded, Padded] = model.layers[0]
- Y: Union[Padded, Ragged, List2d, PaddedData]
if isinstance(Xseq, Padded):
- Y, backprop = layer(Xseq, is_train)
+ return cast(Tuple[SeqT, Callable], layer(Xseq, is_train))
elif isinstance(Xseq, Ragged):
- Y, backprop = _ragged_forward(layer, cast(Ragged, Xseq), is_train)
+ return cast(Tuple[SeqT, Callable], _ragged_forward(layer, Xseq, is_train))
elif _is_padded_data(Xseq):
- Y, backprop = _tuple_forward(layer, cast(PaddedData, Xseq), is_train)
+ return cast(
+ Tuple[SeqT, Callable],
+ _tuple_forward(layer, cast(PaddedData, Xseq), is_train),
+ )
elif is_xp_array(Xseq):
- Y, backprop = _array_forward(layer, cast(Floats3d, Xseq), is_train)
+ return cast(
+ Tuple[SeqT, Callable], _array_forward(layer, cast(Floats3d, Xseq), is_train)
+ )
else:
- Y, backprop = _list_forward(layer, cast(List2d, Xseq), is_train)
- return cast(Tuple[SeqT, Callable], (Y, backprop))
+ return cast(
+ Tuple[SeqT, Callable], _list_forward(layer, cast(List2d, Xseq), is_train)
+ )
def init(
@@ -48,28 +53,31 @@ def init(
)
-def _is_padded_data(seq):
+def _is_padded_data(seq: SeqT) -> bool:
return isinstance(seq, tuple) and len(seq) == 4 and all(map(is_xp_array, seq))
-def _get_padded(model, seq):
+def _get_padded(model: Model, seq: SeqT) -> Padded:
if isinstance(seq, Padded):
return seq
elif isinstance(seq, Ragged):
return model.ops.list2padded(model.ops.unflatten(seq.data, seq.lengths))
elif _is_padded_data(seq):
- return Padded(*seq) # type: ignore
+ return Padded(*seq) # type: ignore[misc]
elif is_xp_array(seq):
- size_at_t = model.ops.asarray1i([seq.shape[1]] * seq.shape[0])
- lengths = model.ops.asarray1i([seq.shape[0]] * seq.shape[1])
- indices = model.ops.xp.arange(seq.shape[1])
- return Padded(seq, size_at_t, lengths, indices)
+ floats3d_seq = cast(Floats3d, seq)
+ size_at_t = model.ops.asarray1i([floats3d_seq.shape[1]] * floats3d_seq.shape[0])
+ lengths = model.ops.asarray1i([floats3d_seq.shape[0]] * floats3d_seq.shape[1])
+ indices = model.ops.xp.arange(floats3d_seq.shape[1])
+ return Padded(floats3d_seq, size_at_t, lengths, indices)
else:
assert isinstance(seq, list), seq
return model.ops.list2padded(seq)
-def _array_forward(layer, X, is_train):
+def _array_forward(
+ layer: Model[Padded, Padded], X: Floats3d, is_train
+) -> Tuple[Floats3d, Callable]:
# Create bogus metadata for Padded.
Xp = _get_padded(layer, X)
Yp, get_dXp = layer(Xp, is_train)
@@ -82,20 +90,24 @@ def backprop(dY: Floats3d) -> Floats3d:
dXp = get_dXp(dYp)
return dXp.data
- return Yp.data, backprop
+ return cast(Floats3d, Yp.data), backprop
-def _tuple_forward(layer, X, is_train: bool):
+def _tuple_forward(
+ layer: Model[Padded, Padded], X: PaddedData, is_train: bool
+) -> Tuple[PaddedData, Callable]:
Yp, get_dXp = layer(Padded(*X), is_train)
def backprop(dY):
dXp = get_dXp(Padded(*dY))
return (dXp.data, dXp.size_at_t, dXp.lengths, dXp.indices)
- return (Yp.data, Yp.size_at_t, Yp.lengths, Yp.indices), backprop
+ return (cast(Floats3d, Yp.data), Yp.size_at_t, Yp.lengths, Yp.indices), backprop
-def _ragged_forward(layer, Xr, is_train):
+def _ragged_forward(
+ layer: Model[Padded, Padded], Xr: Ragged, is_train: bool
+) -> Tuple[Ragged, Callable]:
# Assign these to locals, to keep code a bit shorter.
list2padded = layer.ops.list2padded
padded2list = layer.ops.padded2list
@@ -109,22 +121,24 @@ def _ragged_forward(layer, Xr, is_train):
def backprop(dYr: Ragged):
flattened = flatten(
- padded2list(get_dXp(list2padded(unflatten(dYr.data, dYr.lengths))))
+ padded2list(get_dXp(list2padded(unflatten(dYr.data, dYr.lengths)))),
)
- return Ragged(cast(Floats2d, flattened), dYr.lengths)
+ return Ragged(flattened, dYr.lengths)
flattened = flatten(padded2list(Yp))
return Ragged(flattened, Xr.lengths), backprop
-def _list_forward(layer, Xs, is_train):
+def _list_forward(
+ layer: Model[Padded, Padded], Xs: List2d, is_train: bool
+) -> Tuple[List2d, Callable]:
# Assign these to locals, to keep code a bit shorter.
list2padded = layer.ops.list2padded
padded2list = layer.ops.padded2list
- Yp, get_dXp = layer(list2padded(Xs), is_train) # type: ignore
+ Yp, get_dXp = layer(list2padded(Xs), is_train)
def backprop(dYs):
- return padded2list(get_dXp(list2padded(dYs))) # type: ignore
+ return padded2list(get_dXp(list2padded(dYs)))
return padded2list(Yp), backprop
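For reference, here is a minimal sketch of the list-to-`Padded` round trip that the retyped `_list_forward` relies on. The `NumpyOps` backend and the toy shapes are assumptions for illustration, not part of this patch:

```python
import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
# Two sequences of different lengths, three features per timestep.
Xs = [numpy.ones((4, 3), dtype="f"), numpy.ones((2, 3), dtype="f")]
Xp = ops.list2padded(Xs)   # Padded(data, size_at_t, lengths, indices)
Ys = ops.padded2list(Xp)   # back to a list of 2d arrays, in the original order
assert [y.shape for y in Ys] == [x.shape for x in Xs]
```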
diff --git a/thinc/layers/with_ragged.py b/thinc/layers/with_ragged.py
index db01832f2..005c69048 100644
--- a/thinc/layers/with_ragged.py
+++ b/thinc/layers/with_ragged.py
@@ -1,12 +1,11 @@
-from typing import Tuple, Callable, Optional, TypeVar, Union, cast
+from typing import Tuple, Callable, Optional, TypeVar, cast, List, Union
-from ..types import Padded, Ragged, Ints1d, Array2d, List2d
+from ..types import Padded, Ragged, Array2d, ListXd, List2d, Ints1d
from ..model import Model
from ..config import registry
-
RaggedData = Tuple[Array2d, Ints1d]
-SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, List2d, RaggedData])
+SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, ListXd, RaggedData])
@registry.layers("with_ragged.v1")
@@ -18,20 +17,25 @@ def forward(
model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool
) -> Tuple[SeqT, Callable]:
layer: Model[Ragged, Ragged] = model.layers[0]
- Y: Union[Padded, Ragged, List2d, RaggedData]
if isinstance(Xseq, Ragged):
- Y, backprop = layer(Xseq, is_train)
+ return cast(Tuple[SeqT, Callable], layer(Xseq, is_train))
elif isinstance(Xseq, Padded):
- Y, backprop = _padded_forward(layer, cast(Padded, Xseq), is_train)
+ return cast(Tuple[SeqT, Callable], _padded_forward(layer, Xseq, is_train))
elif _is_ragged_data(Xseq):
- Y, backprop = _tuple_forward(layer, cast(RaggedData, Xseq), is_train)
+ return cast(
+ Tuple[SeqT, Callable],
+ _tuple_forward(layer, cast(RaggedData, Xseq), is_train),
+ )
else:
- Y, backprop = _list_forward(layer, cast(List2d, Xseq), is_train)
- return cast(Tuple[SeqT, Callable], (Y, backprop))
+ return cast(
+ Tuple[SeqT, Callable], _list_forward(layer, cast(List, Xseq), is_train)
+ )
def init(
- model: Model[SeqT, SeqT], X: Optional[SeqT] = None, Y: Optional[SeqT] = None
+ model: Model[SeqT, SeqT],
+ X: Optional[SeqT] = None,
+ Y: Optional[SeqT] = None,
) -> None:
model.layers[0].initialize(
X=_get_ragged(model, X) if X is not None else None,
@@ -39,26 +43,29 @@ def init(
)
-
def _is_ragged_data(seq):
return isinstance(seq, tuple) and len(seq) == 2
-def _get_ragged(model, seq):
+def _get_ragged(model: Model[SeqT, SeqT], seq: SeqT) -> Ragged:
if isinstance(seq, Ragged):
return seq
elif isinstance(seq, Padded):
lists = model.ops.padded2list(seq)
lengths = model.ops.asarray1i([len(x) for x in lists])
return Ragged(model.ops.flatten(lists), lengths)
elif _is_ragged_data(seq):
- return Ragged(*seq)
+ return Ragged(*seq) # type: ignore[misc]
else:
- lengths = model.ops.asarray1i([len(x) for x in seq])
- return Ragged(model.ops.flatten(seq), lengths)
+ list2d_seq = cast(List2d, seq)
+ lengths = model.ops.asarray1i([len(x) for x in list2d_seq])
+ return Ragged(model.ops.flatten(list2d_seq), lengths)
-def _tuple_forward(layer: Model[Ragged, Ragged], X: RaggedData, is_train: bool):
+def _tuple_forward(
+ layer: Model[Ragged, Ragged], X: RaggedData, is_train: bool
+) -> Tuple[RaggedData, Callable]:
Yr, get_dXr = layer(Ragged(*X), is_train)
def backprop(dY: RaggedData) -> RaggedData:
@@ -68,7 +75,9 @@ def backprop(dY: RaggedData) -> RaggedData:
return (Yr.data, Yr.lengths), backprop
-def _padded_forward(layer, Xp, is_train):
+def _padded_forward(
+ layer: Model[Ragged, Ragged], Xp: Padded, is_train: bool
+) -> Tuple[Padded, Callable]:
# Assign these to locals, to keep code a bit shorter.
list2padded = layer.ops.list2padded
padded2list = layer.ops.padded2list
@@ -86,12 +95,18 @@ def _padded_forward(layer, Xp, is_train):
def backprop(dYp: Padded):
flattened = flatten(padded2list(dYp))
- return list2padded(unflatten(get_dXr(Ragged(flattened, lengths)).data, lengths))
+ dXr = get_dXr(Ragged(flattened, lengths))
+ return list2padded(unflatten(dXr.data, lengths))
- return list2padded(unflatten(Yr.data, Yr.lengths)), backprop
+ return (
+ list2padded(unflatten(Yr.data, Yr.lengths)),
+ backprop,
+ )
-def _list_forward(layer, Xs, is_train: bool):
+def _list_forward(
+ layer: Model[Ragged, Ragged], Xs: List, is_train: bool
+) -> Tuple[List, Callable]:
# Assign these to locals, to keep code a bit shorter.
flatten = layer.ops.flatten
unflatten = layer.ops.unflatten
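A short sketch of the list-to-`Ragged` conversion that `_get_ragged` performs for list inputs; the `NumpyOps` backend and the shapes are invented for illustration:

```python
import numpy
from thinc.api import NumpyOps
from thinc.types import Ragged

ops = NumpyOps()
Xs = [numpy.zeros((3, 5), dtype="f"), numpy.zeros((1, 5), dtype="f")]
lengths = ops.asarray1i([len(x) for x in Xs])
Xr = Ragged(ops.flatten(Xs), lengths)   # data: (4, 5), lengths: [3, 1]
assert Xr.data.shape == (4, 5)
assert Xr.lengths.tolist() == [3, 1]
```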
diff --git a/thinc/layers/with_reshape.py b/thinc/layers/with_reshape.py
index bc7dc521b..5bd3e9025 100644
--- a/thinc/layers/with_reshape.py
+++ b/thinc/layers/with_reshape.py
@@ -1,15 +1,16 @@
-from typing import Tuple, Callable, Optional, cast
+from typing import Tuple, Callable, Optional, cast, TypeVar, List
from ..model import Model
from ..config import registry
-from ..types import Array3d, Array2d, Floats3d
+from ..types import Array3d, Array2d
-InT = Array3d
+InT = TypeVar("InT", bound=Array3d)
+OutT = TypeVar("OutT", bound=Array2d)
@registry.layers("with_reshape.v1")
-def with_reshape(layer: Model[Array2d, Array2d]) -> Model[InT, InT]:
+def with_reshape(layer: Model[OutT, OutT]) -> Model[InT, InT]:
"""Reshape data on the way into and out from a layer."""
return Model(
f"with_reshape({layer.name})",
@@ -26,16 +27,15 @@ def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callab
final_shape = list(initial_shape[:-1]) + [layer.get_dim("nO")]
nB = X.shape[0]
nT = X.shape[1]
- X2d = cast(InT, model.ops.reshape(X, (-1, X.shape[2])))
+ X2d = model.ops.reshape(X, (-1, X.shape[2]))
Y2d, Y2d_backprop = layer(X2d, is_train=is_train)
- Y = model.ops.reshape3f(Y2d, *final_shape)
+ Y = model.ops.reshape3(Y2d, *final_shape)
def backprop(dY: InT) -> InT:
- dY_floats = model.ops.asarray3f(cast(Floats3d, dY))
- reshaped = model.ops.reshape2f(dY_floats, nB * nT, -1)
- return Y2d_backprop(model.ops.reshape3f(reshaped, *initial_shape))
+ reshaped = model.ops.reshape2(dY, nB * nT, -1)
+ return Y2d_backprop(model.ops.reshape3(reshaped, *initial_shape))
- return Y, backprop
+ return cast(InT, Y), backprop
def init(
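A usage sketch for the retyped `with_reshape` wrapper: it flattens a 3d batch into 2d for the wrapped layer and restores the shape afterwards. The `Relu` width and the batch shape below are made up:

```python
import numpy
from thinc.api import Relu, with_reshape

model = with_reshape(Relu(16))            # wrap a 2d layer
X = numpy.zeros((8, 10, 4), dtype="f")    # (batch, sequence, features)
model.initialize(X=X)
Y = model.predict(X)
assert Y.shape == (8, 10, 16)
```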
diff --git a/thinc/loss.py b/thinc/loss.py
index 3632e93cf..990b30df1 100644
--- a/thinc/loss.py
+++ b/thinc/loss.py
@@ -1,4 +1,4 @@
-from typing import Tuple, List, cast, TypeVar, Generic, Any, Union, Optional
+from typing import Tuple, Sequence, cast, TypeVar, Generic, Any, Union, Optional, List
from typing import Dict
from .types import Floats2d, Ints1d
@@ -11,7 +11,7 @@
GuessT = TypeVar("GuessT")
TruthT = TypeVar("TruthT")
IntsOrFloats = Union[Ints1d, Floats2d]
-IntsOrFloatsOrStrs = Union[Ints1d, Floats2d, List[int], List[str]]
+IntsOrFloatsOrStrs = Union[Ints1d, Floats2d, Sequence[int], Sequence[str]]
class Loss(Generic[GuessT, TruthT, GradT, LossT]): # pragma: no cover
@@ -35,7 +35,7 @@ def get_loss(self, guesses: GuessT, truths: TruthT) -> LossT:
class CategoricalCrossentropy(Loss):
- names: Optional[List[str]]
+ names: Optional[Sequence[str]]
missing_value: Optional[Union[str, int]]
_name_to_i: Dict[str, int]
@@ -43,7 +43,7 @@ def __init__(
self,
*,
normalize: bool = True,
- names: Optional[List[str]] = None,
+ names: Optional[Sequence[str]] = None,
missing_value: Optional[Union[str, int]] = None,
neg_prefix: Optional[str] = None,
label_smoothing: float = 0.0,
@@ -160,7 +160,7 @@ def _get_loss_from_grad(self, d_truth: Floats2d) -> float:
def configure_CategoricalCrossentropy_v1(
*,
normalize: bool = True,
- names: Optional[List[str]] = None,
+ names: Optional[Sequence[str]] = None,
missing_value: Optional[Union[str, int]] = None,
) -> CategoricalCrossentropy:
return CategoricalCrossentropy(
@@ -172,7 +172,7 @@ def configure_CategoricalCrossentropy_v1(
def configure_CategoricalCrossentropy_v2(
*,
normalize: bool = True,
- names: Optional[List[str]] = None,
+ names: Optional[Sequence[str]] = None,
missing_value: Optional[Union[str, int]] = None,
neg_prefix: Optional[str] = None,
) -> CategoricalCrossentropy:
@@ -188,7 +188,7 @@ def configure_CategoricalCrossentropy_v2(
def configure_CategoricalCrossentropy_v3(
*,
normalize: bool = True,
- names: Optional[List[str]] = None,
+ names: Optional[Sequence[str]] = None,
missing_value: Optional[Union[str, int]] = None,
neg_prefix: Optional[str] = None,
label_smoothing: float = 0.0,
@@ -207,7 +207,7 @@ def __init__(
self,
*,
normalize: bool = True,
- names: Optional[List[str]] = None,
+ names: Optional[Sequence[str]] = None,
missing_value: Optional[Union[str, int]] = None,
neg_prefix: Optional[str] = None,
label_smoothing: float = 0.0,
@@ -222,14 +222,14 @@ def __init__(
self.normalize = normalize
def __call__(
- self, guesses: List[Floats2d], truths: List[IntsOrFloatsOrStrs]
+ self, guesses: Sequence[Floats2d], truths: Sequence[IntsOrFloatsOrStrs]
) -> Tuple[List[Floats2d], float]:
grads = self.get_grad(guesses, truths)
loss = self._get_loss_from_grad(grads)
return grads, loss
def get_grad(
- self, guesses: List[Floats2d], truths: List[IntsOrFloatsOrStrs]
+ self, guesses: Sequence[Floats2d], truths: Sequence[IntsOrFloatsOrStrs]
) -> List[Floats2d]:
err = "Cannot calculate SequenceCategoricalCrossentropy loss: guesses and truths must be same length"
if len(guesses) != len(truths): # pragma: no cover
@@ -244,11 +244,11 @@ def get_grad(
return d_scores
def get_loss(
- self, guesses: List[Floats2d], truths: List[IntsOrFloatsOrStrs]
+ self, guesses: Sequence[Floats2d], truths: Sequence[IntsOrFloatsOrStrs]
) -> float:
return self._get_loss_from_grad(self.get_grad(guesses, truths))
- def _get_loss_from_grad(self, grads: List[Floats2d]) -> float:
+ def _get_loss_from_grad(self, grads: Sequence[Floats2d]) -> float:
loss = 0.0
for grad in grads:
loss += self.cc._get_loss_from_grad(grad)
@@ -257,7 +257,7 @@ def _get_loss_from_grad(self, grads: List[Floats2d]) -> float:
@registry.losses("SequenceCategoricalCrossentropy.v1")
def configure_SequenceCategoricalCrossentropy_v1(
- *, normalize: bool = True, names: Optional[List[str]] = None
+ *, normalize: bool = True, names: Optional[Sequence[str]] = None
) -> SequenceCategoricalCrossentropy:
return SequenceCategoricalCrossentropy(normalize=normalize, names=names)
@@ -266,7 +266,7 @@ def configure_SequenceCategoricalCrossentropy_v1(
def configure_SequenceCategoricalCrossentropy_v2(
*,
normalize: bool = True,
- names: Optional[List[str]] = None,
+ names: Optional[Sequence[str]] = None,
neg_prefix: Optional[str] = None,
) -> SequenceCategoricalCrossentropy:
return SequenceCategoricalCrossentropy(
@@ -278,7 +278,7 @@ def configure_SequenceCategoricalCrossentropy_v2(
def configure_SequenceCategoricalCrossentropy_v3(
*,
normalize: bool = True,
- names: Optional[List[str]] = None,
+ names: Optional[Sequence[str]] = None,
missing_value: Optional[Union[str, int]] = None,
neg_prefix: Optional[str] = None,
label_smoothing: float = 0.0,
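To illustrate the relaxed `Sequence[...]` annotations, a toy `SequenceCategoricalCrossentropy` call with string truths; the label names and scores here are invented:

```python
import numpy
from thinc.api import SequenceCategoricalCrossentropy

loss = SequenceCategoricalCrossentropy(names=("cat", "dog"))  # any Sequence[str]
guesses = [numpy.asarray([[0.9, 0.1], [0.4, 0.6]], dtype="f")]
truths = [["cat", "dog"]]   # one list of labels per sequence
grads, score = loss(guesses, truths)
assert grads[0].shape == (2, 2)
```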
diff --git a/thinc/types.py b/thinc/types.py
index f26789399..74498d159 100644
--- a/thinc/types.py
+++ b/thinc/types.py
@@ -42,7 +42,7 @@
List2d = Union[List["Floats2d"], List["Ints2d"]]
List3d = Union[List["Floats3d"], List["Ints3d"]]
List4d = Union[List["Floats4d"], List["Ints4d"]]
-ListXd = Union[List["FloatsXd"], List["IntsXd"]]
+ListXd = Union[List1d, List2d, List3d, List4d]
ArrayT = TypeVar("ArrayT")
SelfT = TypeVar("SelfT")
@@ -712,7 +712,31 @@ def __get_validators__(cls):
yield lambda v: validate_array(v, ndim=4, dtype="i")
def __iter__(self) -> Iterator[Ints3d]: ...
- # def __getitem__(self, key: int) -> Ints3d: ...
+
+ @overload
+ def __getitem__(self, key: _4_KeyScalar) -> int: ...
+ @overload
+ def __getitem__(self, key: _4_Key1d) -> Ints1d: ...
+ @overload
+ def __getitem__(self, key: _4_Key2d) -> Ints2d: ...
+ @overload
+ def __getitem__(self, key: _4_Key3d) -> Ints3d: ...
+ @overload
+ def __getitem__(self, key: _4_Key4d) -> "Ints4d": ...
+ def __getitem__(self, key: _4_AllKeys) -> _I4_AllReturns: ...
+
+ @overload
+ def __setitem__(self, key: _4_KeyScalar, value: int) -> None: ...
+ @overload
+ def __setitem__(self, key: _4_Key1d, value: Ints1d) -> None: ...
+ @overload
+ def __setitem__(self, key: _4_Key2d, value: Ints2d) -> None: ...
+ @overload
+ def __setitem__(self, key: _4_Key3d, value: Ints3d) -> None: ...
+ @overload
+ def __setitem__(self, key: _4_Key4d, value: "Ints4d") -> None: ...
+
+ def __setitem__(self, key: _4_AllKeys, value: _I4_AllReturns) -> None: ...
@overload
def sum(self, *, keepdims: Tru, axis: _4_AllAx = None, out: Optional["Ints4d"] = None) -> "Ints4d": ...
@@ -782,7 +806,7 @@ class Padded:
and the indices indicates the original ordering.
"""
- data: Floats3d
+ data: Array3d
size_at_t: Ints1d
lengths: Ints1d
indices: Ints1d
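With `Padded.data` widened to `Array3d`, integer batches are now valid as well. A purely illustrative construction (the shapes are invented):

```python
import numpy
from thinc.types import Padded

data = numpy.zeros((5, 2, 4), dtype="i")              # (time, batch, features); ints are now fine
size_at_t = numpy.asarray([2, 2, 2, 1, 1], dtype="i")
lengths = numpy.asarray([5, 3], dtype="i")
indices = numpy.asarray([0, 1], dtype="i")
padded = Padded(data, size_at_t, lengths, indices)
```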
diff --git a/thinc/util.py b/thinc/util.py
index 9dfc755cc..dacfb09c2 100644
--- a/thinc/util.py
+++ b/thinc/util.py
@@ -1,5 +1,5 @@
from typing import Any, Union, Sequence, cast, Dict, Optional, Callable, TypeVar
-from typing import List, Tuple
+from typing import List, Mapping, Tuple
import numpy
from packaging.version import Version
import random
@@ -490,7 +490,11 @@ def partial(
class DataValidationError(ValueError):
def __init__(
- self, name: str, X: Any, Y: Any, errors: List[Dict[str, Any]] = []
+ self,
+ name: str,
+ X: Any,
+ Y: Any,
+ errors: Union[Sequence[Mapping[str, Any]], List[Dict[str, Any]]] = [],
) -> None:
"""Custom error for validating inputs / outputs at runtime."""
message = f"Data validation error in '{name}'"
@@ -564,7 +568,7 @@ def data_validation(validation):
@contextlib.contextmanager
-def use_nvtx_range(message: int, id_color: int = -1):
+def use_nvtx_range(message: str, id_color: int = -1):
"""Context manager to register the executed code as an NVTX range. The
ranges can be used as markers in CUDA profiling."""
if has_cupy:
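The corrected `message: str` annotation matches how `use_nvtx_range` is actually called; a minimal sketch (the range name is arbitrary, and the range is only registered when CuPy is available):

```python
from thinc.util import use_nvtx_range

with use_nvtx_range("forward pass"):
    ...  # code to mark as an NVTX range during CUDA profiling
```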
diff --git a/website/docs/api-layers.md b/website/docs/api-layers.md
index b4c51e4b8..1c43a9d7a 100644
--- a/website/docs/api-layers.md
+++ b/website/docs/api-layers.md
@@ -48,8 +48,10 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/cauchysimilarity.py
-- **Input:** ArrayXd
-- **Output:** ArrayXd
+- **Input:** ArrayXd / Sequence[ArrayXd] /
+ Ragged / Padded
+- **Output:** ArrayXd / Sequence[ArrayXd] /
+  Ragged / Padded
- **Attrs:** `dropout_rate` float
@@ -71,10 +73,10 @@ for node in model.walk():
node.attrs["dropout_rate"] = 0.5
```
-| Argument | Type | Description |
-| -------------- | -------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
-| `dropout_rate` | float | The probability of zeroing the activations (default: 0). Higher dropout rates mean more distortion. Values around `0.2` are often good. |
-| **RETURNS** | Model[ArrayXd, ArrayXd] | The created dropout layer. |
+| Argument | Type | Description |
+| -------------- | -------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
+| `dropout_rate` | float | The probability of zeroing the activations (default: 0). Higher dropout rates mean more distortion. Values around `0.2` are often good. |
+| **RETURNS** | Model[T, T] | The created dropout layer. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/dropout.py
@@ -84,8 +86,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/dropout.py
-- **Input:** Ints1d /
- Ints2d
+- **Input:** Union[Ints1d, Ints2d]
- **Output:** Floats2d
- **Parameters:** E
- **Attrs:** `column` int, `dropout_rate` float
@@ -114,8 +115,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/embed.py
-- **Input:** Ints1d /
- Ints2d
+- **Input:** Union[Ints1d, Ints2d]
- **Output:** Floats2d
- **Parameters:** E
- **Attrs:** `seed` Optional[int], `column` int,
@@ -238,8 +238,8 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/sigmoid.py
-- **Input:** Floats2d
-- **Output:** Floats2d
+- **Input:** FloatsXd
+- **Output:** FloatsXd
@@ -252,7 +252,7 @@ element of the output vectors will be between `0` and `1`.
| **RETURNS** | Model[Floats2d, Floats2d] | The created `sigmoid_activation` layer. |
```python
-https://github.com/explosion/thinc/blob/master/thinc/layers/sigmoid_logistic.py
+https://github.com/explosion/thinc/blob/master/thinc/layers/sigmoid_activation.py
```
### LSTM and BiLSTM {#lstm tag="function"}
@@ -779,34 +779,6 @@ of the lengths should equal the length of the keys and values array.
https://github.com/explosion/thinc/blob/master/thinc/layers/sparselinear.pyx
```
-### StaticVectors {#staticvectors tag="function"}
-
-
-
-- **Input:** Ints2d
-- **Output:** Floats2d
-- **Attrs:** `column` int, `vectors` Optional[Floats2d],
- `dropout_rate` float
-
-
-
-
-
-| Argument | Type | Description |
-| -------------- | -------------------------------- | --------------------------------------------------- |
-| `nO` | Optional[int] | The size of the output vectors. |
-| `vectors` | Optional[Floats2d] | The vectors. |
-| _keyword-only_ | | |
-| `column` | int | The column of values to slice for the indices. |
-| `dropout` | Optional[float] | Dropout rate to avoid overfitting (default `None`). |
-| **RETURNS** | Model[Ints2d, Floats2d] | The created embedding layer. |
-
-```python
-https://github.com/explosion/thinc/blob/master/thinc/layers/staticvectors.py
-```
-
----
-
## Reduction operations {#reduction-ops}
### reduce_first {#reduce_first tag="function"}
@@ -814,7 +786,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/staticvectors.py
- **Input:** Ragged
-- **Output:** Floats2d
+- **Output:** ArrayXd
@@ -823,9 +795,9 @@ item of each sequence. This is most useful after multi-head attention layers,
which can learn to assign a good feature representation for the sequence to one
of its elements.
-| Argument | Type | Description |
-| ----------- | -------------------------------- | -------------------------- |
-| **RETURNS** | Model[Ragged, Floats2d] | The created pooling layer. |
+| Argument | Type | Description |
+| ----------- | ------------------------------- | -------------------------- |
+| **RETURNS** | Model[Ragged, ArrayXd] | The created pooling layer. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/reduce_first.py
@@ -841,13 +813,13 @@ representation for the sequence to its final element.
- **Input:** Ragged
-- **Output:** Floats2d
+- **Output:** ArrayXd
-| Argument | Type | Description |
-| ----------- | -------------------------------- | -------------------------- |
-| **RETURNS** | Model[Ragged, Floats2d] | The created pooling layer. |
+| Argument | Type | Description |
+| ----------- | ------------------------------- | -------------------------- |
+| **RETURNS** | Model[Ragged, ArrayXd] | The created pooling layer. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/reduce_last.py
@@ -930,10 +902,10 @@ to `>>` allows you to write `Relu(512) >> Softmax()` instead of
Compose two or more models `f`, `g`, etc, such that their outputs are added,
i.e. `add(f, g)(x)` computes `f(x) + g(x)`.
-| Argument | Type | Description |
-| ----------- | -------------------------------- | ---------------------- |
-| `*layers` | Model[ArrayXd, ArrayXd] | The models to compose. |
-| **RETURNS** | Model[ArrayXd, ArrayXd] | The composed model. |
+| Argument | Type | Description |
+| ----------- | ---------------------------- | ---------------------- |
+| `*layers` | Model[Any, ArrayXd] | The models to compose. |
+| **RETURNS** | Model[Any, ArrayXd] | The composed model. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/add.py
@@ -955,13 +927,15 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/bidirectional.py
### chain {#chain tag="function"}
-Compose two models `f` and `g` such that they become layers of a single
-feed-forward model that computes `g(f(x))`.
+Compose two or more models such that they become layers of a single feed-forward
+model, e.g. `chain(f, g)` computes `g(f(x))`.
-| Argument | Type | Description |
-| ----------- | -------------------------------- | -------------------------------- |
-| `*layers` | Model[ArrayXd, ArrayXd] | The models to compose. |
-| **RETURNS** | Model[ArrayXd, ArrayXd] | The composed feed-forward model. |
+| Argument | Type | Description |
+| ----------- | -------------- | --------------------------------- |
+| `layer1`    | Model          | The first model to compose.       |
+| `layer2` | Model | The second model to compose. |
+| `*layers` | Model | Any additional models to compose. |
+| **RETURNS** | Model | The composed feed-forward model. |
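For example (a toy pipeline; the layer widths are arbitrary):

```python
import numpy
from thinc.api import Relu, Softmax, chain

model = chain(Relu(8), Softmax(4))      # computes Softmax(Relu(X))
X = numpy.zeros((2, 5), dtype="f")
model.initialize(X=X)
assert model.predict(X).shape == (2, 4)
```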
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/chain.py
@@ -972,11 +946,11 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/chain.py
Construct `n` copies of a layer, with distinct weights. For example,
`clone(f, 3)(x)` computes `f(f'(f''(x)))`.
-| Argument | Type | Description |
-| ----------- | -------------------------------- | ---------------------------------- |
-| `orig` | Model[ArrayXd, ArrayXd] | The layer to copy. |
-| `n` | int | The number of copies to construct. |
-| **RETURNS** | Model[ArrayXd, ArrayXd] | The composed model. |
+| Argument | Type | Description |
+| ----------- | -------------- | ------------------------------------------------ |
+| `orig` | Model | The layer to copy. |
+| `n` | int | The number of copies to construct. |
+| **RETURNS** | Model | A composite model containing two or more copies. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/clone.py
@@ -987,10 +961,10 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/clone.py
Compose two or more models `f`, `g`, etc, such that their outputs are
concatenated, i.e. `concatenate(f, g)(x)` computes `hstack(f(x), g(x))`.
-| Argument | Type | Description |
-| ----------- | -------------------------------- | ---------------------- |
-| `*layers` | Model[ArrayXd, ArrayXd] | The models to compose. |
-| **RETURNS** | Model[ArrayXd, ArrayXd] | The composed model. |
+| Argument | Type | Description |
+| ----------- | ------------------- | ---------------------- |
+| `*layers` | Model, ... | The models to compose. |
+| **RETURNS** | Model | The composed model. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/concatenate.py
@@ -1025,10 +999,10 @@ and a window of surrounding vectors. This is one step in a convolution. If the
concatenating three contextual vectors from the left, and three from the right,
to each input vector. In general, `nO` equals `nI * (2 * window_size + 1)`.
-| Argument | Type | Description |
-| ------------- | ------------------------ | ------------------------------------------------------------------------------ |
-| `window_size` | int | The window size (default 1) that determines the number of surrounding vectors. |
-| **RETURNS** | Model[InT, InT] | The created layer for adding context to vectors. |
+| Argument | Type | Description |
+| ------------- | -------------------- | ------------------------------------------------------------------------------ |
+| `window_size` | int | The window size (default 1) that determines the number of surrounding vectors. |
+| **RETURNS** | Model[T, T] | The created layer for adding context to vectors. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/expand_window.py
@@ -1038,10 +1012,10 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/expand_window.py
 Transform a sequence of layers into a null operation.
-| Argument | Type | Description |
-| ----------- | -------------------------------- | ---------------------- |
-| `*layers` | Model[ArrayXd, ArrayXd] | The models to compose. |
-| **RETURNS** | Model[ArrayXd, ArrayXd] | The composed model. |
+| Argument | Type | Description |
+| ----------- | -------------- | ---------------------- |
+| `*layers` | Model | The models to compose. |
+| **RETURNS** | Model | The composed model. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/noop.py
@@ -1051,8 +1025,14 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/noop.py
-- **Input:** List[FloatsXd], Ragged, Padded, FloatsXd
-- **Output:** List[FloatsXd], Ragged, Padded, FloatsXd
+- **Input:** List[FloatsXd] / Ragged /
+  Padded / FloatsXd / Floats1d /
+  Floats2d / Floats3d / Floats4d
+- **Output:** List[FloatsXd] / Ragged /
+  Padded / FloatsXd / Floats1d /
+  Floats2d / Floats3d / Floats4d
@@ -1078,10 +1058,10 @@ input to a downstream layer.
On the backward pass the loss from each child is added together, so when using
custom datatypes they should define an addition operator.
-| Argument | Type | Description |
-| ----------- | -------------------------------- | -------------------------------- |
-| `*layers` | Model[ArrayXd, ArrayXd] | The models to compose. |
-| **RETURNS** | Model[ArrayXd, ArrayXd] | The composed feed-forward model. |
+| Argument | Type | Description |
+| ----------- | ----------------------------- | -------------------------------- |
+| `*layers` | Model[Any, T] ... | The models to compose. |
+| **RETURNS** | Model[Any, Tuple[T]] | The composed feed-forward model. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/tuplify.py
@@ -1112,12 +1092,12 @@ minibatch. The `uniqued` wrapper is useful for word inputs, because common words
are seen often, but we may want to compute complicated features for the words,
using e.g. character LSTM.
-| Argument | Type | Description |
-| -------------- | --------------------------------- | ---------------------------- |
-| `layer` | Model | The layer. |
-| _keyword-only_ | | |
-| `column` | int | The column. Defaults to `0`. |
-| **RETURNS** | Model[ArrayXd, FloatsXd] | The composed model. |
+| Argument | Type | Description |
+| -------------- | -------------------------------- | ---------------------------- |
+| `layer` | Model | The layer. |
+| _keyword-only_ | | |
+| `column` | int | The column. Defaults to `0`. |
+| **RETURNS** | Model[Ints2d, Floats2d] | The composed model. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/uniqued.py
@@ -1153,7 +1133,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/array_getitem.py
-- **Input:** List[Array2d]
+- **Input:** List2d
- **Output:** Array2d
@@ -1162,9 +1142,9 @@ Transform sequences to ragged arrays if necessary. If sequences are already
ragged, do nothing. A ragged array is a tuple `(data, lengths)`, where `data` is
the concatenated data.
-| Argument | Type | Description |
-| ----------- | -------------------------------------- | ---------------------------------------- |
-| **RETURNS** | Model[List[Array2d], Array2d] | The layer to compute the transformation. |
+| Argument | Type | Description |
+| ----------- | ------------------------------- | ---------------------------------------- |
+| **RETURNS** | Model[List2d, Array2d] | The layer to compute the transformation. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/list2array.py
@@ -1174,7 +1154,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/list2array.py
-- **Input:** List[Floats2d]
+- **Input:** ListXd
- **Output:** Ragged
@@ -1183,9 +1163,9 @@ Transform sequences to ragged arrays if necessary and return the ragged array.
If sequences are already ragged, do nothing. A ragged array is a tuple
`(data, lengths)`, where `data` is the concatenated data.
-| Argument | Type | Description |
-| ----------- | ------------------------------------- | ---------------------------------------- |
-| **RETURNS** | Model[List[Array2d], Ragged] | The layer to compute the transformation. |
+| Argument | Type | Description |
+| ----------- | ------------------------------ | ---------------------------------------- |
+| **RETURNS** | Model[ListXd, Ragged] | The layer to compute the transformation. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/list2ragged.py
@@ -1195,7 +1175,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/list2ragged.py
-- **Input:** List[Array2d]
+- **Input:** List2d
- **Output:** Padded
@@ -1203,9 +1183,9 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/list2ragged.py
Create a layer to convert a list of array inputs into
[`Padded`](/docs/api-types#padded).
-| Argument | Type | Description |
-| ----------- | ------------------------------------- | ---------------------------------------- |
-| **RETURNS** | Model[List[Array2d], Padded] | The layer to compute the transformation. |
+| Argument | Type | Description |
+| ----------- | ------------------------------ | ---------------------------------------- |
+| **RETURNS** | Model[List2d, Padded] | The layer to compute the transformation. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/list2padded.py
@@ -1216,15 +1196,15 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/list2padded.py
- **Input:** Ragged
-- **Output:** List[Floats2d]
+- **Output:** ListXd
Transform sequences from a ragged format into lists.
-| Argument | Type | Description |
-| ----------- | -------------------------------------- | ---------------------------------------- |
-| **RETURNS** | Model[Ragged, List[Floats2d]] | The layer to compute the transformation. |
+| Argument | Type | Description |
+| ----------- | ------------------------------ | ---------------------------------------- |
+| **RETURNS** | Model[Ragged, ListXd] | The layer to compute the transformation. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/ragged2list.py
@@ -1235,16 +1215,16 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/ragged2list.py
- **Input:** Padded
-- **Output:** List[Array]
+- **Output:** List2d
Create a layer to convert a [`Padded`](/docs/api-types#padded) input into a list
of arrays.
-| Argument | Type | Description |
-| ----------- | ----------------------------------- | ---------------------------------------- |
-| **RETURNS** | Model[Padded, List[Array]] | The layer to compute the transformation. |
+| Argument | Type | Description |
+| ----------- | ------------------------------ | ---------------------------------------- |
+| **RETURNS** | Model[Padded, List2d] | The layer to compute the transformation. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/padded2list.py
@@ -1298,7 +1278,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/strings2arrays.py
-- **Input / output:** Union[Padded, Ragged, List[Array2d], ArrayXd]
+- **Input / output:** Union[Padded, Ragged, ListXd, ArrayXd]
@@ -1308,7 +1288,7 @@ input is an array, it is passed through unchanged.
| Argument | Type | Description |
| -------------- | -------------------------------- | ----------------------------- |
-| `layer` | Model[Array2d, Array2d] | The layer to wrap. |
+| `layer` | Model[ArrayXd, ArrayXd] | The layer to wrap. |
| _keyword-only_ | | |
| `pad` | int | The padding. Defaults to `0`. |
| **RETURNS** | Model | The wrapped layer. |
@@ -1321,7 +1301,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/with_array2d.py
-- **Input / output:** Union[Padded, Ragged, List[Array2d], Array2d]
+- **Input / output:** Union[Padded, Ragged, List2d, Array2d]
@@ -1348,17 +1328,17 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/with_array.py
- **Input:** Sequence[Sequence[Any]]
-- **Output:** List[Array2d]
+- **Output:** ListXd
Flatten nested inputs on the way into a layer and reverse the transformation
over the outputs.
-| Argument | Type | Description |
-| ----------- | -------------- | ------------------ |
-| `layer` | Model | The layer to wrap. |
-| **RETURNS** | Model | The wrapped layer. |
+| Argument | Type | Description |
+| ----------- | ---------------------------------------------------------------- | ------------------ |
+| `layer` | Model[Sequence[Sequence[Any]], Sequence[Sequence[Any]]] | The layer to wrap. |
+| **RETURNS** | Model[ListXd, ListXd] | The wrapped layer. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/with_flatten.py
@@ -1368,7 +1348,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/with_flatten.py
-- **Input / output:** Union[Padded, Ragged, List[Array2d], Floats3d,
+- **Input / output:** Union[Padded, Ragged, List2d, Floats3d,
Tuple[Floats3d, Ints1d, Ints1d, Ints1d]]
@@ -1389,7 +1369,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/with_padded.py
-- **Input / output:** Union[Padded, Ragged, List[Array2d], Floats3d,
+- **Input / output:** Union[Padded, Ragged, ListXd, Floats3d,
Tuple[Floats2d, Ints1d]]
@@ -1410,17 +1390,17 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/with_ragged.py
-- **Input / output:** Union[Padded, Ragged, List[Array2d]]
+- **Input / output:** Union[Padded, Ragged, List2d]
Convert sequence input into lists on the way into a layer and reverse the
transformation on the outputs.
-| Argument | Type | Description |
-| ----------- | -------------------------------------------- | ------------------ |
-| `layer` | Model[List[Array2d], List[Array2d]] | The layer to wrap. |
-| **RETURNS** | Model | The wrapped layer. |
+| Argument | Type | Description |
+| ----------- | ------------------------------ | ------------------ |
+| `layer` | Model[List2d, List2d] | The layer to wrap. |
+| **RETURNS** | Model | The wrapped layer. |
```python
https://github.com/explosion/thinc/blob/master/thinc/layers/with_list.py