diff --git a/azure-pipelines.yml b/azure-pipelines.yml index c88f78c92..301ab7341 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -61,7 +61,6 @@ jobs: displayName: 'Build sdist' - script: | - python -m pip install mypy==0.910 python -m mypy thinc displayName: 'Run mypy' diff --git a/requirements.txt b/requirements.txt index 1d8b9491f..4a31fc017 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ wasabi>=0.8.1,<1.1.0 catalogue>=2.0.4,<2.1.0 ml_datasets>=0.2.0,<0.3.0 # Third-party dependencies -pydantic>=1.7.4,!=1.8,!=1.8.1,<1.9.0 +pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0 numpy>=1.15.0 # Backports of modern Python features dataclasses>=0.6,<1.0; python_version < "3.7" @@ -22,8 +22,7 @@ pytest-cov>=2.7.0,<2.8.0 coverage>=5.0.0,<6.0.0 mock>=2.0.0,<3.0.0 flake8>=3.5.0,<3.6.0 -# restricting mypy until faster 3.10 wheels are available -mypy>=0.901,<0.920; python_version < "3.10" +mypy>=0.901,<0.960 types-mock>=0.1.1 types-contextvars>=0.1.2; python_version < "3.7" types-dataclasses>=0.1.3; python_version < "3.7" diff --git a/setup.cfg b/setup.cfg index 59b69419a..4e535d899 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,7 +31,7 @@ python_requires = >=3.6 setup_requires = cython>=0.25,<3.0 numpy>=1.15.0 - # We also need our Cython packages here to compile against + # We also need our Cython packages here to compile against cymem>=2.0.2,<2.1.0 preshed>=3.0.2,<3.1.0 murmurhash>=1.0.2,<1.1.0 @@ -48,7 +48,7 @@ install_requires = # Third-party dependencies setuptools numpy>=1.15.0 - pydantic>=1.7.4,!=1.8,!=1.8.1,<1.9.0 + pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0 # Backports of modern Python features dataclasses>=0.6,<1.0; python_version < "3.7" typing_extensions>=3.7.4.1,<4.0.0.0; python_version < "3.8" diff --git a/thinc/backends/ops.py b/thinc/backends/ops.py index 1d1a374f5..315b0b0bf 100644 --- a/thinc/backends/ops.py +++ b/thinc/backends/ops.py @@ -5,9 +5,9 @@ import numpy import itertools -from .. import registry from ..types import Xp, Shape, DTypes, DTypesInt, DTypesFloat, List2d, ArrayXd -from ..types import Array3d, Floats1d, Floats2d, Floats3d, Floats4d +from ..types import Floats1d, Floats2d, Floats3d, Floats4d +from ..types import Array1d, Array2d, Array3d, Array4d, ListXd from ..types import FloatsXd, Ints1d, Ints2d, Ints3d, Ints4d, IntsXd, _Floats from ..types import DeviceTypes, Generator, Padded, Batchable, SizedGenerator from ..util import get_array_module, is_xp_array, to_numpy @@ -135,13 +135,11 @@ def _get_batch(self, sequence, indices): if isinstance(sequence, list): subseq = [sequence[i] for i in indices] elif isinstance(sequence, tuple): - subseq = tuple(sequence[i] for i in indices) # type: ignore + subseq = tuple(sequence[i] for i in indices) else: - subseq = sequence[indices] # type: ignore + subseq = sequence[indices] if is_xp_array(subseq): - subseq = self.as_contig( - cast(ArrayXd, self.xp.asarray(subseq)) - ) # type: ignore + subseq = self.as_contig(self.xp.asarray(subseq)) return subseq def _get_batch_sizes(self, length: int, sizes: Iterator[int]): @@ -225,13 +223,65 @@ def affine(self, X: Floats2d, W: Floats2d, b: Floats1d) -> Floats2d: Y += b return Y + @overload def flatten( self, - X: Sequence[ArrayT], + X: List[Floats2d], dtype: Optional[DTypes] = None, pad: int = 0, ndim_if_empty: int = 2, - ) -> ArrayT: + ) -> Floats2d: + ... + + @overload + def flatten( + self, + X: List[Ints1d], + dtype: Optional[DTypes] = None, + pad: int = 0, + ndim_if_empty: int = 2, + ) -> Ints1d: + ... 
+ + @overload + def flatten( + self, + X: List2d, + dtype: Optional[DTypes] = None, + pad: int = 0, + ndim_if_empty: int = 2, + ) -> Array2d: + ... + + # further specific typed signatures can be added as necessary + + @overload + def flatten( + self, + X: ListXd, + dtype: Optional[DTypes] = None, + pad: int = 0, + ndim_if_empty: int = 2, + ) -> ArrayXd: + ... + + @overload + def flatten( + self, + X: Sequence[ArrayXd], + dtype: Optional[DTypes] = None, + pad: int = 0, + ndim_if_empty: int = 2, + ) -> ArrayXd: + ... + + def flatten( + self, + X: Sequence[ArrayXd], + dtype: Optional[DTypes] = None, + pad: int = 0, + ndim_if_empty: int = 2, + ) -> ArrayXd: """Flatten a list of arrays into one large array.""" if X is None or len(X) == 0: return self.alloc((0,) * ndim_if_empty, dtype=dtype or "f") @@ -252,7 +302,25 @@ def flatten( result = xp.asarray(result, dtype=dtype) return result + @overload def unflatten(self, X: Floats2d, lengths: Ints1d, pad: int = 0) -> List[Floats2d]: + ... + + @overload + def unflatten(self, X: Ints1d, lengths: Ints1d, pad: int = 0) -> List[Ints1d]: + ... + + @overload + def unflatten(self, X: Array2d, lengths: Ints1d, pad: int = 0) -> List2d: + ... + + # further specific typed signatures can be added as necessary + + @overload + def unflatten(self, X: ArrayXd, lengths: Ints1d, pad: int = 0) -> ListXd: + ... + + def unflatten(self, X: ArrayXd, lengths: Ints1d, pad: int = 0) -> ListXd: """The reverse/backward operation of the `flatten` function: unflatten a large array into a list of arrays according to the given lengths. """ @@ -302,7 +370,7 @@ def pad( # noqa: F811 output: Array3d = self.alloc(final_shape, dtype=seqs[0].dtype) for i, arr in enumerate(seqs): # It's difficult to convince this that the dtypes will match. - output[i, : arr.shape[0]] = arr # type: ignore + output[i, : arr.shape[0]] = arr # type: ignore[assignment, call-overload] return output def unpad(self, padded: Array3d, lengths: List[int]) -> List2d: @@ -314,14 +382,14 @@ def unpad(self, padded: Array3d, lengths: List[int]) -> List2d: output.append(padded[i, :length]) return cast(List2d, output) - def list2padded(self, seqs: List[Floats2d]) -> Padded: + def list2padded(self, seqs: List2d) -> Padded: """Pack a sequence of 2d arrays into a Padded datatype.""" if not seqs: return Padded( self.alloc3f(0, 0, 0), self.alloc1i(0), self.alloc1i(0), self.alloc1i(0) ) elif len(seqs) == 1: - data = self.reshape3f(seqs[0], seqs[0].shape[0], 1, seqs[0].shape[1]) + data = self.reshape3(seqs[0], seqs[0].shape[0], 1, seqs[0].shape[1]) size_at_t = self.asarray1i([1] * data.shape[0]) lengths = self.asarray1i([data.shape[0]]) indices = self.asarray1i([0]) @@ -336,8 +404,8 @@ def list2padded(self, seqs: List[Floats2d]) -> Padded: # Reorder the sequences, by length. This looks the same in either # direction: you're swapping elements between their original and sorted # position. 
- seqs = [seqs[i] for i in indices_] - arr: Floats3d = self.pad(seqs) + seqs = cast(List2d, [seqs[i] for i in indices_]) + arr: Array3d = self.pad(seqs) assert arr.shape == (nB, nS, nO), (nB, nS, nO) arr = self.as_contig(arr.transpose((1, 0, 2))) assert arr.shape == (nS, nB, nO) @@ -350,7 +418,7 @@ def list2padded(self, seqs: List[Floats2d]) -> Padded: batch_size_at_t_[t] = current_size assert sum(lengths_) == sum(batch_size_at_t_) return Padded( - cast(Floats3d, arr), + arr, self.asarray1i(batch_size_at_t_), self.asarray1i(lengths_), self.asarray1i(indices_), @@ -361,7 +429,7 @@ def padded2list(self, padded: Padded) -> List2d: data = padded.data indices = to_numpy(padded.indices) lengths = to_numpy(padded.lengths) - unpadded: List[Optional[Floats2d]] = [None] * len(lengths) + unpadded: List[Optional[Array2d]] = [None] * len(lengths) # Transpose from (length, batch, data) to (batch, length, data) data = self.as_contig(data.transpose((1, 0, 2))) for i in range(data.shape[0]): @@ -500,6 +568,18 @@ def alloc( else: return self.xp.empty(shape, dtype=dtype) + def reshape1(self, array: ArrayXd, d0: int) -> Array1d: + return cast(Array1d, self.reshape(array, (d0,))) + + def reshape2(self, array: ArrayXd, d0: int, d1: int) -> Array2d: + return cast(Array2d, self.reshape(array, (d0, d1))) + + def reshape3(self, array: ArrayXd, d0: int, d1: int, d2: int) -> Array3d: + return cast(Array3d, self.reshape(array, (d0, d1, d2))) + + def reshape4(self, array: ArrayXd, d0: int, d1: int, d2: int, d3: int) -> Array4d: + return cast(Array4d, self.reshape(array, (d0, d1, d2, d3))) + def reshape1f(self, array: FloatsXd, d0: int) -> Floats1d: return cast(Floats1d, self.reshape(array, (d0,))) @@ -619,7 +699,7 @@ def asarray( return self.xp.asarray(data, dtype=dtype) elif hasattr(data, "numpy"): # Handles PyTorch Tensor - return data.numpy() # type: ignore + return data.numpy() # type: ignore[union-attr] elif dtype is not None: return self.xp.array(data, dtype=dtype) else: @@ -641,8 +721,8 @@ def sigmoid(self, X: FloatsType, *, inplace: bool = False) -> FloatsType: if inplace: self.xp.exp(-X, out=X) - X += 1.0 # type: ignore - X **= -1.0 # type: ignore + X += 1.0 # type: ignore[assignment] + X **= -1.0 # type: ignore[assignment] return cast(FloatsType, X) else: return cast(FloatsType, 1.0 / (1.0 + self.xp.exp(-X))) @@ -786,10 +866,10 @@ def clipped_linear( inplace: bool = False, ) -> FloatsType: if inplace: - X *= slope # type: ignore - X += offset # type: ignore + X *= slope # type: ignore[assignment] + X += offset # type: ignore[assignment] return cast(FloatsType, self.xp.clip(X, min_val, max_val, out=X)) - out = X * slope + offset # type: ignore + out = X * slope + offset # type: ignore[assignment] return cast(FloatsType, self.xp.clip(out, min_val, max_val)) def backprop_clipped_linear( @@ -840,27 +920,27 @@ def backprop_hard_tanh( def swish(self, X: FloatsType, inplace: bool = False) -> FloatsType: if inplace: - X *= self.sigmoid(X) # type: ignore + X *= self.sigmoid(X) # type: ignore[operator, assignment] return cast(FloatsType, X) - out = X * self.sigmoid(X) # type: ignore + out = X * self.sigmoid(X) # type: ignore[operator] return cast(FloatsType, out) def backprop_swish( self, dY: FloatsType, X: FloatsType, Y: FloatsType, inplace: bool = False ) -> FloatsType: - Y = Y + self.sigmoid(X) * (1 - Y) # type: ignore + Y = Y + self.sigmoid(X) * (1 - Y) # type: ignore[operator] if inplace: - dY *= Y # type: ignore + dY *= Y # type: ignore[operator, assignment] return cast(FloatsType, dY) - out = dY * Y # type: 
ignore + out = dY * Y # type: ignore[operator] return cast(FloatsType, out) # Following https://www.scitepress.org/Papers/2019/74696/74696.pdf def hard_swish(self, X: FloatsType, inplace: bool = False) -> FloatsType: if inplace: - X *= self.hard_sigmoid(X) # type: ignore + X *= self.hard_sigmoid(X) # type: ignore[operator, assignment] return cast(FloatsType, X) - out = X * self.hard_sigmoid(X) # type: ignore + out = X * self.hard_sigmoid(X) # type: ignore[operator] return cast(FloatsType, out) def backprop_hard_swish( @@ -927,7 +1007,7 @@ def gelu_approx(self, X: FloatsType, inplace: bool = False) -> FloatsType: else: Y = self.xp.array(X) Y *= tmp - return cast(FloatsType, Y) + return Y def backprop_gelu_approx( self, dY: FloatsType, X: FloatsType, inplace: bool = False @@ -949,15 +1029,15 @@ def gelu(self, X: FloatsType, inplace: bool = False) -> FloatsType: # GELU(x) = x · Φ(x) cdf = gaussian_cdf(self, X) if inplace: - X *= cdf # type: ignore + X *= cdf # type: ignore[operator, assignment] return X - return X * cdf # type: ignore + return X * cdf # type: ignore[operator, return-value] def backprop_gelu( self, dY: FloatsType, X: FloatsType, inplace: bool = False ) -> FloatsType: # GELU'(x) = Φ(x) + x · PDF(x) - dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X) # type: ignore + dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X) # type: ignore[operator] if inplace: dY *= dX return dY @@ -1239,8 +1319,8 @@ def lstm_forward_training( for d in range(dirs): # The inits are shaped (depth, dirs, nO). We add the internal dimension # to make them set correctly. - Yt2 = h_init[i, d].reshape((1, nO)) # type: ignore - Ct2 = c_init[i, d].reshape((1, nO)) # type: ignore + Yt2 = h_init[i, d].reshape((1, nO)) # type: ignore[assignment] + Ct2 = c_init[i, d].reshape((1, nO)) # type: ignore[assignment] layer_params, params_i = _split_weights(params, i, nO, nI, params_i) Wx, Wh, bias = _transpose_weights(layer_params) G[i, d] += xp.dot(X, Wx.T) diff --git a/thinc/config.py b/thinc/config.py index 167f8fd97..837f91b76 100644 --- a/thinc/config.py +++ b/thinc/config.py @@ -1,4 +1,4 @@ -from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type +from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping from typing import Iterable, Sequence, cast from types import GeneratorType from dataclasses import dataclass @@ -550,7 +550,7 @@ def __init__( self, *, config: Optional[Union[Config, Dict[str, Dict[str, Any]], str]] = None, - errors: Iterable[Dict[str, Any]] = tuple(), + errors: Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]] = tuple(), title: Optional[str] = "Config validation error", desc: Optional[str] = None, parent: Optional[str] = None, @@ -560,9 +560,10 @@ def __init__( config (Union[Config, Dict[str, Dict[str, Any]], str]): The config the validation error refers to. - errors (Iterable[Dict[str, Any]]): A list of errors as dicts with keys - "loc" (list of strings describing the path of the value), "msg" - (validation message to show) and optional "type" (mostly internals). + errors (Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]]): + A list of errors as dicts with keys "loc" (list of strings + describing the path of the value), "msg" (validation message + to show) and optional "type" (mostly internals). Same format as produced by pydantic's validation error (e.errors()). title (str): The error title. desc (str): Optional error description, displayed below the title. 
diff --git a/thinc/layers/array_getitem.py b/thinc/layers/array_getitem.py index 87a62e9b8..17ffcb7ee 100644 --- a/thinc/layers/array_getitem.py +++ b/thinc/layers/array_getitem.py @@ -1,13 +1,14 @@ -from typing import Union, Sequence, Tuple +from typing import Union, Sequence, Tuple, TypeVar from ..types import ArrayXd, FloatsXd, IntsXd from ..model import Model AxisIndex = Union[int, slice, Sequence[int]] Index = Union[AxisIndex, Tuple[AxisIndex, ...]] +ArrayTXd = TypeVar("ArrayTXd", bound=ArrayXd) -def array_getitem(index: Index) -> Model[ArrayXd, ArrayXd]: +def array_getitem(index: Index) -> Model[ArrayTXd, ArrayTXd]: """Index into input arrays, and return the subarrays. index: diff --git a/thinc/layers/cauchysimilarity.py b/thinc/layers/cauchysimilarity.py index 89c70078b..25af8d9df 100644 --- a/thinc/layers/cauchysimilarity.py +++ b/thinc/layers/cauchysimilarity.py @@ -23,14 +23,15 @@ def CauchySimilarity(nI: Optional[int] = None) -> Model[InT, OutT]: params={"W": None}, ) + def forward( model: Model[InT, OutT], X1_X2: InT, is_train: bool ) -> Tuple[OutT, Callable]: X1, X2 = X1_X2 W = cast(Floats2d, model.get_param("W")) diff = X1 - X2 - square_diff = diff ** 2 - total = (W * square_diff).sum(axis=1) # type: ignore + square_diff = diff**2 + total = (W * square_diff).sum(axis=1) sim, bp_sim = inverse(total) def backprop(d_sim: OutT) -> InT: diff --git a/thinc/layers/chain.py b/thinc/layers/chain.py index 324319b61..258ee0902 100644 --- a/thinc/layers/chain.py +++ b/thinc/layers/chain.py @@ -1,4 +1,4 @@ -from typing import Tuple, Callable, Optional, TypeVar, Any, Dict +from typing import Tuple, Callable, Optional, TypeVar, Any, Dict, List, cast from ..model import Model from ..config import registry @@ -7,9 +7,8 @@ InT = TypeVar("InT") -OutT = TypeVar("OutT") MidT = TypeVar("MidT") - +OutT = TypeVar("OutT") # Keep this function so we can provide variable arguments via the config @registry.layers("chain.v1") @@ -18,29 +17,31 @@ def chain_no_types(*layer: Model) -> Model: def chain( - layer1: Model[InT, MidT], layer2: Model[MidT, OutT], *layers: Model + layer1: Model[InT, MidT], layer2: Model[MidT, Any], *layers: Model[Any, Any] ) -> Model[InT, XY_YZ_OutT]: """Compose two models `f` and `g` such that they become layers of a single feed-forward model that computes `g(f(x))`. Also supports chaining more than 2 layers. + Note that the type checking for additional layers is carried out by the Thinc Mypy plugin. 
""" - layers = (layer1, layer2) + layers + all_layers: List[Model[Any, Any]] = [layer1, layer2] + all_layers.extend(layers) dims: Dict[str, Optional[int]] = {"nO": None} # set input dimension only if first layer has one - should be "False" otherwise - if layers[0].has_dim("nI") is True: - dims["nI"] = layers[0].get_dim("nI") - if layers[0].has_dim("nI") is None: + if all_layers[0].has_dim("nI") is True: + dims["nI"] = all_layers[0].get_dim("nI") + if all_layers[0].has_dim("nI") is None: dims["nI"] = None # set output dimension according to last layer - if layers[-1].has_dim("nO") is True: - dims["nO"] = layers[-1].get_dim("nO") + if all_layers[-1].has_dim("nO") is True: + dims["nO"] = all_layers[-1].get_dim("nO") - model: Model[InT, Any] = Model( - ">>".join(layer.name for layer in layers), + model: Model[InT, XY_YZ_OutT] = Model( + ">>".join(layer.name for layer in all_layers), forward, init=init, dims=dims, - layers=layers, + layers=all_layers, ) return model @@ -65,7 +66,9 @@ def backprop(dY: OutT) -> InT: def init( - model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None + model: Model[InT, OutT], + X: Optional[InT] = None, + Y: Optional[OutT] = None, ) -> None: if X is None and Y is None: for layer in model.layers: @@ -92,10 +95,9 @@ def init( model.set_dim("nI", model.layers[0].get_dim("nI")) if model.has_dim("nO") is None: try: - nO = get_width(curr_input) # type: ignore + nO = get_width(curr_input) # type: ignore[arg-type] + model.set_dim("nO", nO) except ValueError: if model.layers[-1].has_dim("nO"): nO = model.layers[-1].get_dim("nO") - else: - nO = None # type: ignore - model.set_dim("nO", nO) + model.set_dim("nO", nO) diff --git a/thinc/layers/concatenate.py b/thinc/layers/concatenate.py index c9faefd63..78e4c558b 100644 --- a/thinc/layers/concatenate.py +++ b/thinc/layers/concatenate.py @@ -1,5 +1,5 @@ -from typing import Any, List, Tuple, Callable, Optional, TypeVar, cast, Dict, Union - +from typing import Any, List, Tuple, Callable, Optional +from typing import TypeVar, cast, Dict, Union, Sequence from ..model import Model from ..config import registry from ..types import Array2d, Ragged @@ -9,7 +9,7 @@ InT = TypeVar("InT", bound=Any) -OutT = TypeVar("OutT", bound=Union[Array2d, List[Array2d], Ragged]) +OutT = TypeVar("OutT", bound=Union[Array2d, Sequence[Array2d], Ragged]) @registry.layers("concatenate.v1") @@ -43,15 +43,18 @@ def concatenate(*layers: Model) -> Model[InT, XY_XY_OutT]: def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]: Ys, callbacks = zip(*[layer(X, is_train=is_train) for layer in model.layers]) if isinstance(Ys[0], list): - return _list_forward(model, X, Ys, callbacks, is_train) # type: ignore + data_l, backprop = _list_forward(model, X, Ys, callbacks, is_train) + return cast(OutT, data_l), backprop elif isinstance(Ys[0], Ragged): - return _ragged_forward(model, X, Ys, callbacks, is_train) # type: ignore + data_r, backprop = _ragged_forward(model, X, Ys, callbacks, is_train) + return cast(OutT, data_r), backprop else: - return _array_forward(model, X, Ys, callbacks, is_train) # type: ignore + data_a, backprop = _array_forward(model, X, Ys, callbacks, is_train) + return cast(OutT, data_a), backprop def _array_forward( - model: Model[InT, Array2d], X, Ys, callbacks, is_train: bool + model: Model[InT, OutT], X, Ys: List, callbacks, is_train: bool ) -> Tuple[Array2d, Callable]: widths = [Y.shape[1] for Y in Ys] output = model.ops.xp.hstack(Ys) @@ -61,7 +64,9 @@ def backprop(d_output: Array2d) -> InT: dX = 
callbacks[0](dY) start = widths[0] add_gradients = hasattr(dX, "__add__") or hasattr(dX, "__iadd__") - add_gradients_data = hasattr(dX, "data") and (hasattr(dX.data, "__add__") or hasattr(dX.data, "__iadd__")) + add_gradients_data = hasattr(dX, "data") and ( + hasattr(dX.data, "__add__") or hasattr(dX.data, "__iadd__") + ) for bwd, width in zip(callbacks[1:], widths[1:]): dY = model.ops.as_contig(d_output[:, start : start + width]) gradient = bwd(dY) @@ -76,7 +81,7 @@ def backprop(d_output: Array2d) -> InT: def _ragged_forward( - model: Model[InT, Ragged], X, Ys, callbacks, is_train: bool + model: Model[InT, OutT], X, Ys: List, callbacks, is_train: bool ) -> Tuple[Ragged, Callable]: widths = [Y.dataXd.shape[1] for Y in Ys] @@ -98,28 +103,28 @@ def backprop(d_output: Ragged) -> InT: return output, backprop -def _list_forward(model: Model[InT, List[Array2d]], X, Ys, callbacks, is_train: bool): - lengths = model.ops.asarray1i([len(x) for x in X]) - Ys = [model.ops.xp.concatenate(Y, axis=0) for Y in Ys] - widths = [Y.shape[1] for Y in Ys] - out_array = model.ops.xp.hstack(Ys) - output = model.ops.unflatten(out_array, lengths) - - def backprop(d_output: List[Array2d]) -> InT: +def _list_forward( + model: Model[InT, OutT], X, Ys: List, callbacks, is_train: bool +) -> Tuple[Sequence[Array2d], Callable]: + def backprop(d_output: Sequence[Array2d]) -> InT: d_out_array = model.ops.xp.concatenate(d_output, axis=0) dY = model.ops.as_contig(d_out_array[:, : widths[0]]) # We want to generalize unflatten later. - dY = model.ops.unflatten(dY, lengths) # type: ignore + dY = model.ops.unflatten(dY, lengths) dX = callbacks[0](dY) start = widths[0] for bwd, width in zip(callbacks[1:], widths[1:]): dY = model.ops.as_contig(d_out_array[:, start : start + width]) - dY = model.ops.unflatten(dY, lengths) # type: ignore + dY = model.ops.unflatten(dY, lengths) dX += bwd(dY) start += width return dX - return output, backprop + lengths = model.ops.asarray1i([len(x) for x in X]) + Ys = [model.ops.xp.concatenate(Y, axis=0) for Y in Ys] + widths = [Y.shape[1] for Y in Ys] + out_array = model.ops.xp.hstack(Ys) + return model.ops.unflatten(out_array, lengths), backprop def init( diff --git a/thinc/layers/dropout.py b/thinc/layers/dropout.py index fedc3310b..f4fa29445 100644 --- a/thinc/layers/dropout.py +++ b/thinc/layers/dropout.py @@ -1,12 +1,11 @@ -from typing import Tuple, Callable, List, TypeVar, Any +from typing import Tuple, Callable, List, TypeVar, cast, Union, Sequence from ..model import Model from ..config import registry from ..types import ArrayXd, Ragged, Padded -InT = TypeVar("InT") -ArrayT = TypeVar("ArrayT", bound=ArrayXd) +InT = TypeVar("InT", bound=Union[ArrayXd, Sequence[ArrayXd], Ragged, Padded]) @registry.layers("Dropout.v1") @@ -18,40 +17,39 @@ def Dropout(rate: float = 0.0) -> Model[InT, InT]: return Model("dropout", forward, attrs={"dropout_rate": rate, "is_enabled": True}) -# We're getting type hell here, I think because of the instance checks? -# It's sort of painful, because I think this confused the types of other -# layers that are trying to use dropout. -# I've relaxed the types for now, but it'd be good to understand what's wrong -# here. 
-def forward(model: Model, X, is_train: bool) -> Tuple[Any, Callable]: +def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]: rate = model.attrs["dropout_rate"] is_enabled = model.attrs["is_enabled"] and is_train if rate == 0 or not is_enabled: return X, lambda dY: dY elif isinstance(X, Ragged): - return _dropout_ragged(model, X, is_train) + data_r, backprop = _dropout_ragged(model, X, is_train) + return cast(InT, data_r), backprop elif isinstance(X, Padded): - return _dropout_padded(model, X, is_train) - elif isinstance(X, list): - return _dropout_lists(model, X, is_train) + data_p, backprop = _dropout_padded(model, X, is_train) + return cast(InT, data_p), backprop + elif isinstance(X, Sequence): + data_l, backprop = _dropout_lists(model, X, is_train) + return cast(InT, data_l), backprop else: - return _dropout_array(model, X, is_train) + data_a, backprop = _dropout_array(model, cast(ArrayXd, X), is_train) + return cast(InT, data_a), backprop def _dropout_array( - model: Model[ArrayT, ArrayT], X: ArrayT, is_train: bool -) -> Tuple[ArrayT, Callable]: + model: Model[InT, InT], X: ArrayXd, is_train: bool +) -> Tuple[ArrayXd, Callable]: rate = model.attrs["dropout_rate"] mask = model.ops.get_dropout_mask(X.shape, rate) - def backprop(dY: ArrayT) -> ArrayT: + def backprop(dY: ArrayXd) -> ArrayXd: return dY * mask - return X * mask, backprop + return cast(ArrayXd, X * mask), backprop def _dropout_padded( - model: Model, Xp: Padded, is_train: bool + model: Model[InT, InT], Xp: Padded, is_train: bool ) -> Tuple[Padded, Callable]: X = Xp.data mask = model.ops.get_dropout_mask(X.shape, model.attrs["dropout_rate"]) @@ -64,7 +62,7 @@ def backprop(dYp: Padded) -> Padded: def _dropout_ragged( - model: Model, Xr: Ragged, is_train: bool + model: Model[InT, InT], Xr: Ragged, is_train: bool ) -> Tuple[Ragged, Callable]: X = Xr.data lengths = Xr.lengths @@ -78,13 +76,13 @@ def backprop(dYr: Ragged) -> Ragged: def _dropout_lists( - model: Model[ArrayT, ArrayT], Xs: List[ArrayT], is_train: bool -) -> Tuple[List[ArrayT], Callable]: + model: Model[InT, InT], Xs: Sequence[ArrayXd], is_train: bool +) -> Tuple[Sequence[ArrayXd], Callable]: rate = model.attrs["dropout_rate"] masks = [model.ops.get_dropout_mask(X.shape, rate) for X in Xs] Ys = [X * mask for X, mask in zip(Xs, masks)] - def backprop(dYs: List[ArrayT]) -> List[ArrayT]: + def backprop(dYs: List[ArrayXd]) -> List[ArrayXd]: return [dY * mask for dY, mask in zip(dYs, masks)] return Ys, backprop diff --git a/thinc/layers/embed.py b/thinc/layers/embed.py index 80b25266d..9e3587460 100644 --- a/thinc/layers/embed.py +++ b/thinc/layers/embed.py @@ -1,4 +1,4 @@ -from typing import Dict, Callable, Tuple, Optional, Union, cast +from typing import Dict, Callable, Tuple, Optional, Union, cast, TypeVar from .chain import chain from .array_getitem import ints_getitem @@ -9,7 +9,7 @@ from ..util import get_width, partial -InT = Union[Ints1d, Ints2d] +InT = TypeVar("InT", bound=Union[Ints1d, Ints2d]) OutT = Floats2d @@ -26,7 +26,7 @@ def Embed( attrs: Dict[str, Union[None, int, float]] = {} if dropout is not None: attrs["dropout_rate"] = dropout - model = Model( # type: ignore + model: Model = Model( "embed", forward, init=partial(init, initializer), @@ -45,7 +45,7 @@ def Embed( def forward( - model: Model[InT, OutT], ids: Ints1d, is_train: bool + model: Model[Ints1d, OutT], ids: Ints1d, is_train: bool ) -> Tuple[OutT, Callable]: vectors = cast(Floats2d, model.get_param("E")) nO = vectors.shape[1] @@ -72,7 +72,7 @@ def backprop(d_output: 
OutT) -> Ints1d: def init( initializer: Callable, - model: Model[InT, OutT], + model: Model[Ints1d, OutT], X: Optional[Ints1d] = None, Y: Optional[OutT] = None, ) -> None: diff --git a/thinc/layers/expand_window.py b/thinc/layers/expand_window.py index 15b4534f5..1075a49a2 100644 --- a/thinc/layers/expand_window.py +++ b/thinc/layers/expand_window.py @@ -24,7 +24,7 @@ def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callab def _expand_window_floats( - model: Model[Floats2d, Floats2d], X: Floats2d + model: Model[InT, InT], X: Floats2d ) -> Tuple[Floats2d, Callable]: nW = model.attrs["window_size"] if len(X) > 0: @@ -40,7 +40,7 @@ def backprop(dY: Floats2d) -> Floats2d: def _expand_window_ragged( - model: Model[Ragged, Ragged], Xr: Ragged + model: Model[InT, InT], Xr: Ragged ) -> Tuple[Ragged, Callable]: nW = model.attrs["window_size"] Y = Ragged( diff --git a/thinc/layers/hashembed.py b/thinc/layers/hashembed.py index 1830aff93..74b85c7cf 100644 --- a/thinc/layers/hashembed.py +++ b/thinc/layers/hashembed.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Tuple, Optional, Any, Union, cast +from typing import Callable, Dict, Tuple, Optional, Any, Union, cast, TypeVar from .chain import chain from .array_getitem import ints_getitem @@ -9,7 +9,7 @@ from ..util import partial -InT = Union[Ints2d, Ints1d] +InT = TypeVar("InT", bound=Union[Ints1d, Ints2d]) OutT = Floats2d @@ -35,7 +35,7 @@ def HashEmbed( attrs: Dict[str, Any] = {"column": column, "seed": seed} if dropout is not None: attrs["dropout_rate"] = dropout - model = Model( # type: ignore + model: Model = Model( "hashembed", forward, init=partial(init, initializer), @@ -56,7 +56,7 @@ def HashEmbed( def forward( - model: Model[InT, OutT], ids: Ints1d, is_train: bool + model: Model[Ints1d, OutT], ids: Ints1d, is_train: bool ) -> Tuple[OutT, Callable]: vectors = cast(Floats2d, model.get_param("E")) nV = vectors.shape[0] @@ -64,7 +64,7 @@ def forward( if len(ids) == 0: output: Floats2d = model.ops.alloc((0, nO), dtype=vectors.dtype) else: - ids = model.ops.as_contig(ids, dtype="uint64") # type: ignore + ids = model.ops.as_contig(ids, dtype="uint64") nN = ids.shape[0] seed: int = model.attrs["seed"] keys = model.ops.hash(ids, seed) % nV @@ -92,7 +92,7 @@ def backprop(d_vectors: OutT) -> Ints1d: def init( initializer: Callable, - model: Model[InT, OutT], + model: Model[Ints1d, OutT], X: Optional[Ints1d] = None, Y: Optional[OutT] = None, ) -> None: diff --git a/thinc/layers/list2array.py b/thinc/layers/list2array.py index 2b4b03141..fff5befc0 100644 --- a/thinc/layers/list2array.py +++ b/thinc/layers/list2array.py @@ -1,12 +1,12 @@ -from typing import Tuple, Callable +from typing import Tuple, Callable, TypeVar, List, Union, cast from ..model import Model from ..config import registry -from ..types import Array2d, List2d +from ..types import Array2d -InT = List2d -OutT = Array2d +OutT = TypeVar("OutT", bound=Array2d) +InT = List[OutT] @registry.layers("list2array.v1") @@ -22,6 +22,6 @@ def forward(model: Model[InT, OutT], Xs: InT, is_train: bool) -> Tuple[OutT, Cal lengths = model.ops.asarray1i([len(x) for x in Xs]) def backprop(dY: OutT) -> InT: - return model.ops.unflatten(dY, lengths) # type: ignore + return model.ops.unflatten(dY, lengths) - return model.ops.flatten(Xs), backprop # type: ignore + return model.ops.flatten(Xs), backprop diff --git a/thinc/layers/list2padded.py b/thinc/layers/list2padded.py index 28aabace7..2a02f90e0 100644 --- a/thinc/layers/list2padded.py +++ b/thinc/layers/list2padded.py @@ -1,11 
+1,11 @@ -from typing import Tuple, Callable +from typing import Tuple, Callable, TypeVar, cast from ..types import Padded, List2d from ..model import Model from ..config import registry -InT = List2d +InT = TypeVar("InT", bound=List2d) OutT = Padded @@ -16,9 +16,9 @@ def list2padded() -> Model[InT, OutT]: def forward(model: Model[InT, OutT], Xs: InT, is_train: bool) -> Tuple[OutT, Callable]: - Yp = model.ops.list2padded(Xs) # type: ignore + Yp = model.ops.list2padded(Xs) def backprop(dYp: OutT) -> InT: - return model.ops.padded2list(dYp) # type: ignore + return cast(InT, model.ops.padded2list(dYp)) return Yp, backprop diff --git a/thinc/layers/list2ragged.py b/thinc/layers/list2ragged.py index 7b293fa3d..a63237dfe 100644 --- a/thinc/layers/list2ragged.py +++ b/thinc/layers/list2ragged.py @@ -1,11 +1,11 @@ -from typing import Tuple, List, Callable +from typing import Tuple, List, Callable, cast, TypeVar from ..model import Model from ..config import registry -from ..types import ArrayXd, Ragged +from ..types import ListXd, ArrayXd, Ragged -InT = List[ArrayXd] +InT = TypeVar("InT", bound=ListXd) OutT = Ragged @@ -20,7 +20,7 @@ def list2ragged() -> Model[InT, OutT]: def forward(model: Model[InT, OutT], Xs: InT, is_train: bool) -> Tuple[OutT, Callable]: def backprop(dYr: OutT) -> InT: - return model.ops.unflatten(dYr.data, dYr.lengths) # type: ignore + return cast(InT, model.ops.unflatten(dYr.data, dYr.lengths)) lengths = model.ops.asarray1i([len(x) for x in Xs]) return Ragged(model.ops.flatten(Xs), lengths), backprop diff --git a/thinc/layers/lstm.py b/thinc/layers/lstm.py index 626d2f567..2acc70ac3 100644 --- a/thinc/layers/lstm.py +++ b/thinc/layers/lstm.py @@ -45,13 +45,13 @@ def PyTorchLSTM( from .pytorchwrapper import PyTorchRNNWrapper if depth == 0: - return noop() # type: ignore + return noop() # type: ignore[misc] nH = nO if bi: nH = nO // 2 pytorch_rnn = PyTorchRNNWrapper( - torch.nn.LSTM(nI, nH, depth, bidirectional=bi, dropout=dropout) - ) + torch.nn.LSTM(nI, nH, depth, bidirectional=bi, dropout=dropout) + ) pytorch_rnn.set_dim("nO", nO) pytorch_rnn.set_dim("nI", nI) return with_padded(pytorch_rnn) @@ -161,7 +161,7 @@ def _padded_to_packed(ops: Ops, Xp: Padded) -> Ragged: start = 0 for t in range(Xp.size_at_t.shape[0]): batch_size = Xp.size_at_t[t] - Y[start : start + batch_size] = Xp.data[t, :batch_size] + Y[start : start + batch_size] = Xp.data[t, :batch_size] # type: ignore[assignment] start += batch_size return Ragged(Y, Xp.size_at_t) diff --git a/thinc/layers/maxout.py b/thinc/layers/maxout.py index 6104e0c0e..4f361f78d 100644 --- a/thinc/layers/maxout.py +++ b/thinc/layers/maxout.py @@ -54,7 +54,7 @@ def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Call def backprop(d_best: OutT) -> InT: dZ = model.ops.backprop_maxout(d_best, which, nP) # TODO: Add sum methods for Floats3d - model.inc_grad("b", dZ.sum(axis=0)) # type: ignore + model.inc_grad("b", dZ.sum(axis=0)) # type: ignore[call-overload] dY = model.ops.reshape2f(dZ, dZ.shape[0], nO * nP) dW = model.ops.reshape3f(model.ops.gemm(dY, X, trans1=True), nO, nP, nI) model.inc_grad("W", dW) diff --git a/thinc/layers/padded2list.py b/thinc/layers/padded2list.py index 42a6e4c26..8f1bee7e8 100644 --- a/thinc/layers/padded2list.py +++ b/thinc/layers/padded2list.py @@ -1,4 +1,4 @@ -from typing import Tuple, Callable +from typing import Tuple, Callable, TypeVar, cast from ..types import Padded, List2d from ..model import Model @@ -6,7 +6,7 @@ InT = Padded -OutT = List2d +OutT = TypeVar("OutT", bound=List2d) 
@registry.layers("padded2list.v1") @@ -15,11 +15,13 @@ def padded2list() -> Model[InT, OutT]: return Model(f"padded2list", forward) -def forward(model: Model[InT, OutT], Xp: InT, is_train: bool) -> Tuple[OutT, Callable]: - Ys = model.ops.padded2list(Xp) # type: ignore +def forward( + model: Model[InT, OutT], Xp: InT, is_train: bool +) -> Tuple[OutT, Callable[[OutT], InT]]: + Ys = cast(OutT, model.ops.padded2list(Xp)) def backprop(dYs: OutT) -> InT: - dYp = model.ops.list2padded(dYs) # type: ignore + dYp = model.ops.list2padded(dYs) assert isinstance(dYp, Padded) return dYp diff --git a/thinc/layers/ragged2list.py b/thinc/layers/ragged2list.py index 2c5f321c5..35af28f2f 100644 --- a/thinc/layers/ragged2list.py +++ b/thinc/layers/ragged2list.py @@ -1,12 +1,12 @@ -from typing import Tuple, Callable +from typing import Tuple, Callable, TypeVar, cast from ..model import Model from ..config import registry -from ..types import Ragged, List2d +from ..types import Ragged, ListXd InT = Ragged -OutT = List2d +OutT = TypeVar("OutT", bound=ListXd) @registry.layers("ragged2list.v1") @@ -19,7 +19,8 @@ def forward(model: Model[InT, OutT], Xr: InT, is_train: bool) -> Tuple[OutT, Cal lengths = Xr.lengths def backprop(dXs: OutT) -> InT: - return Ragged(model.ops.flatten(dXs, pad=0), lengths) # type: ignore + return Ragged(model.ops.flatten(dXs, pad=0), lengths) # type:ignore[arg-type] + # type ignore necessary for older versions of Mypy/Pydantic - data = model.ops.unflatten(Xr.dataXd, Xr.lengths) # type: ignore + data = cast(OutT, model.ops.unflatten(Xr.dataXd, Xr.lengths)) return data, backprop diff --git a/thinc/layers/reduce_first.py b/thinc/layers/reduce_first.py index bfc23f613..df7541315 100644 --- a/thinc/layers/reduce_first.py +++ b/thinc/layers/reduce_first.py @@ -7,17 +7,20 @@ OutT = TypeVar("OutT", bound=ArrayXd) + @registry.layers("reduce_first.v1") def reduce_first() -> Model[Ragged, OutT]: """Reduce ragged-formatted sequences to their first element.""" return Model("reduce_first", forward) -def forward(model: Model[Ragged, OutT], Xr: Ragged, is_train: bool) -> Tuple[OutT, Callable]: +def forward( + model: Model[Ragged, OutT], Xr: Ragged, is_train: bool +) -> Tuple[OutT, Callable[[OutT], Ragged]]: starts = model.ops.alloc1i(Xr.lengths.shape[0]) starts[1:] += Xr.lengths.cumsum()[:-1] - X = cast(OutT, Xr.dataXd) - Y = cast(OutT, X[starts]) # type: ignore + X = Xr.dataXd + Y = cast(OutT, X[starts]) x_shape = Xr.dataXd.shape lengths = Xr.lengths @@ -25,8 +28,8 @@ def forward(model: Model[Ragged, OutT], Xr: Ragged, is_train: bool) -> Tuple[Out def backprop(dY: OutT) -> Ragged: array_info.check_consistency(dY) - dX = cast(OutT, model.ops.alloc(x_shape, dtype=dY.dtype)) - dX[starts] = dY # type: ignore + dX: OutT = model.ops.alloc(x_shape, dtype=dY.dtype) + dX[starts] = dY # type: ignore[assignment] return Ragged(dX, lengths) return Y, backprop diff --git a/thinc/layers/reduce_last.py b/thinc/layers/reduce_last.py index 51bda3690..e45a65d12 100644 --- a/thinc/layers/reduce_last.py +++ b/thinc/layers/reduce_last.py @@ -7,23 +7,26 @@ OutT = TypeVar("OutT", bound=ArrayXd) + @registry.layers("reduce_last.v1") def reduce_last() -> Model[Ragged, OutT]: """Reduce ragged-formatted sequences to their last element.""" return Model("reduce_last", forward) -def forward(model: Model[Ragged, OutT], Xr: Ragged, is_train: bool) -> Tuple[OutT, Callable]: +def forward( + model: Model[Ragged, OutT], Xr: Ragged, is_train: bool +) -> Tuple[OutT, Callable[[OutT], Ragged]]: ends = Xr.lengths.cumsum() - 1 - Y = cast(OutT, 
Xr.dataXd[ends]) # type: ignore + Y = cast(OutT, Xr.dataXd[ends]) x_shape = Xr.dataXd.shape lengths = Xr.lengths array_info = ArrayInfo.from_array(Y) def backprop(dY: OutT) -> Ragged: array_info.check_consistency(dY) - dX = cast(OutT, model.ops.alloc(x_shape, dtype=dY.dtype)) - dX[ends] = dY # type: ignore + dX: OutT = model.ops.alloc(x_shape, dtype=dY.dtype) + dX[ends] = dY # type: ignore[assignment] return Ragged(dX, lengths) return Y, backprop diff --git a/thinc/layers/residual.py b/thinc/layers/residual.py index f4062cad1..3793ee1d5 100644 --- a/thinc/layers/residual.py +++ b/thinc/layers/residual.py @@ -4,9 +4,10 @@ from ..config import registry from ..types import Floats1d, Floats2d, Floats3d, Floats4d, FloatsXd, Ragged, Padded - # fmt: off -InT = TypeVar("InT", List[Floats1d], List[Floats2d], List[Floats3d], List[Floats4d], Ragged, Padded, FloatsXd) +InT = TypeVar( + "InT", List[Floats1d], List[Floats2d], List[Floats3d], List[Floats4d], + Ragged, Padded, FloatsXd, Floats1d, Floats2d, Floats3d, Floats4d) # fmt: on diff --git a/thinc/layers/resizable.py b/thinc/layers/resizable.py index 3459296ce..3454684d0 100644 --- a/thinc/layers/resizable.py +++ b/thinc/layers/resizable.py @@ -10,8 +10,7 @@ @registry.layers("resizable.v1") def resizable(layer, resize_layer: Callable) -> Model[InT, OutT]: - """Container that holds one layer that can change dimensions. - """ + """Container that holds one layer that can change dimensions.""" return Model( f"resizable({layer.name})", forward, diff --git a/thinc/layers/siamese.py b/thinc/layers/siamese.py index cc3c2ca71..82bafacbb 100644 --- a/thinc/layers/siamese.py +++ b/thinc/layers/siamese.py @@ -9,7 +9,7 @@ LayerT = TypeVar("LayerT") SimT = TypeVar("SimT") InT = Tuple[LayerT, LayerT] -OutT = ArrayXd +OutT = TypeVar("OutT", bound=ArrayXd) @registry.layers("siamese.v1") diff --git a/thinc/layers/sigmoid_activation.py b/thinc/layers/sigmoid_activation.py index 02923b6eb..8b3982aea 100644 --- a/thinc/layers/sigmoid_activation.py +++ b/thinc/layers/sigmoid_activation.py @@ -17,6 +17,8 @@ def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callab Y = model.ops.sigmoid(X, inplace=False) def backprop(dY: InT) -> InT: - return dY * model.ops.dsigmoid(Y, inplace=False) # type: ignore + return cast( + InT, dY * model.ops.dsigmoid(Y, inplace=False) # type:ignore[operator] + ) return Y, backprop diff --git a/thinc/layers/tuplify.py b/thinc/layers/tuplify.py index b95082ddf..99b4d7589 100644 --- a/thinc/layers/tuplify.py +++ b/thinc/layers/tuplify.py @@ -1,15 +1,16 @@ -from typing import Callable, Optional, Tuple, Any, TypeVar +from typing import Optional, Tuple, Any, TypeVar from ..model import Model from ..config import registry InT = TypeVar("InT") OutT = Tuple -MidT = TypeVar("MidT") @registry.layers("tuplify.v1") -def tuplify(layer1: Model[InT, Any], layer2: Model[InT, Any], *layers) -> Model[InT, Tuple]: +def tuplify( + layer1: Model[InT, Any], layer2: Model[InT, Any], *layers +) -> Model[InT, Tuple]: """Send a separate copy of the input to each child layer, and join the outputs of the children into a tuple on the way out. 
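A quick sketch of that behaviour: both children receive the same input, and their outputs come back as a tuple. The layer sizes are arbitrary, and the import assumes tuplify is re-exported from thinc.api as in recent 8.x releases:

    import numpy
    from thinc.api import Linear, Relu, tuplify

    # Both children see the same (5, 4) input; the composed model returns a 2-tuple.
    model = tuplify(Linear(nO=3, nI=4), Relu(nO=2, nI=4))
    X = numpy.zeros((5, 4), dtype="f")
    model.initialize(X=X)
    Y_linear, Y_relu = model.predict(X)
    # Y_linear has shape (5, 3); Y_relu has shape (5, 2).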
diff --git a/thinc/layers/with_array.py b/thinc/layers/with_array.py index 3d6d91833..3701fc8a3 100644 --- a/thinc/layers/with_array.py +++ b/thinc/layers/with_array.py @@ -2,20 +2,19 @@ from ..model import Model from ..config import registry -from ..types import Array2d, Floats2d, Padded, Ragged, ArrayXd, Floats3d -from ..types import List2d +from ..types import Padded, Ragged, ArrayXd, Array3d, ListXd - -SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, List2d, ArrayXd]) +ArrayTXd = TypeVar("ArrayTXd", bound=ArrayXd) +SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, ListXd, ArrayXd]) @registry.layers("with_array.v1") -def with_array(layer: Model[ArrayXd, ArrayXd], pad: int = 0) -> Model[SeqT, SeqT]: - """Transform sequence data into a contiguous 2d array on the way into and +def with_array(layer: Model[ArrayTXd, ArrayTXd], pad: int = 0) -> Model[SeqT, SeqT]: + """Transform sequence data into a contiguous array on the way into and out of a model. Handles a variety of sequence types: lists, padded and ragged. - If the input is a 2d array, it is passed through unchanged. + If the input is an array, it is passed through unchanged. """ - return Model( + model: Model[SeqT, SeqT] = Model( f"with_array({layer.name})", forward, init=init, @@ -23,21 +22,20 @@ def with_array(layer: Model[ArrayXd, ArrayXd], pad: int = 0) -> Model[SeqT, SeqT attrs={"pad": pad}, dims={name: layer.maybe_get_dim(name) for name in layer.dim_names}, ) + return model -def forward(model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool): +def forward( + model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool +) -> Tuple[SeqT, Callable]: if isinstance(Xseq, Ragged): - return _ragged_forward( - cast(Model[Ragged, Ragged], model), cast(Ragged, Xseq), is_train - ) + return cast(Tuple[SeqT, Callable], _ragged_forward(model, Xseq, is_train)) elif isinstance(Xseq, Padded): - return _padded_forward( - cast(Model[Padded, Padded], model), cast(Padded, Xseq), is_train - ) + return cast(Tuple[SeqT, Callable], _padded_forward(model, Xseq, is_train)) elif not isinstance(Xseq, (list, tuple)): return model.layers[0](Xseq, is_train) else: - return _list_forward(cast(Model[List2d, List2d], model), Xseq, is_train) + return cast(Tuple[SeqT, Callable], _list_forward(model, Xseq, is_train)) def init( @@ -66,16 +64,16 @@ def _get_array(model, X: SeqT) -> ArrayXd: def _list_forward( - model: Model[List2d, List2d], Xs: List2d, is_train: bool -) -> Tuple[List2d, Callable]: - layer = model.layers[0] + model: Model[SeqT, SeqT], Xs: ListXd, is_train: bool +) -> Tuple[ListXd, Callable]: + layer: Model[ArrayXd, ArrayXd] = model.layers[0] pad = model.attrs["pad"] lengths = layer.ops.asarray1i([len(seq) for seq in Xs]) - Xf = layer.ops.flatten(Xs, pad=pad) # type: ignore + Xf = layer.ops.flatten(Xs, pad=pad) Yf, get_dXf = layer(Xf, is_train) - def backprop(dYs: List2d) -> List2d: - dYf = layer.ops.flatten(dYs, pad=pad) # type: ignore + def backprop(dYs: ListXd) -> ListXd: + dYf = layer.ops.flatten(dYs, pad=pad) dXf = get_dXf(dYf) return layer.ops.unflatten(dXf, lengths, pad=pad) @@ -83,7 +81,7 @@ def backprop(dYs: List2d) -> List2d: def _ragged_forward( - model: Model[Ragged, Ragged], Xr: Ragged, is_train: bool + model: Model[SeqT, SeqT], Xr: Ragged, is_train: bool ) -> Tuple[Ragged, Callable]: layer: Model[ArrayXd, ArrayXd] = model.layers[0] Y, get_dX = layer(Xr.dataXd, is_train) @@ -95,9 +93,9 @@ def backprop(dYr: Ragged) -> Ragged: def _padded_forward( - model: Model[Padded, Padded], Xp: Padded, is_train: bool + model: Model[SeqT, SeqT], Xp: Padded, 
is_train: bool ) -> Tuple[Padded, Callable]: - layer: Model[Floats3d, Floats3d] = model.layers[0] + layer: Model[Array3d, Array3d] = model.layers[0] Y, get_dX = layer(Xp.data, is_train) def backprop(dYp: Padded) -> Padded: diff --git a/thinc/layers/with_array2d.py b/thinc/layers/with_array2d.py index 15f4420d2..9f7de213c 100644 --- a/thinc/layers/with_array2d.py +++ b/thinc/layers/with_array2d.py @@ -1,9 +1,8 @@ -from typing import Tuple, Callable, Optional, TypeVar, Union, cast +from typing import Tuple, Callable, Optional, TypeVar, cast, List, Union from ..model import Model from ..config import registry -from ..types import Array2d, Floats2d, Padded, Ragged, ArrayXd -from ..types import List2d +from ..types import Array2d, Floats2d, List2d, Padded, Ragged ValT = TypeVar("ValT", bound=Array2d) @@ -26,19 +25,18 @@ def with_array2d(layer: Model[ValT, ValT], pad: int = 0) -> Model[SeqT, SeqT]: ) -def forward(model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool): +def forward( + model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool +) -> Tuple[SeqT, Callable]: if isinstance(Xseq, Ragged): - return _ragged_forward( - cast(Model[Ragged, Ragged], model), cast(Ragged, Xseq), is_train - ) + return cast(Tuple[SeqT, Callable], _ragged_forward(model, Xseq, is_train)) elif isinstance(Xseq, Padded): - return _padded_forward( - cast(Model[Padded, Padded], model), cast(Padded, Xseq), is_train - ) + return cast(Tuple[SeqT, Callable], _padded_forward(model, Xseq, is_train)) elif not isinstance(Xseq, (list, tuple)): return model.layers[0](Xseq, is_train) else: - return _list_forward(cast(Model[List2d, List2d], model), Xseq, is_train) + return cast(Tuple[SeqT, Callable], _list_forward(model, Xseq, is_train)) + return def init( @@ -69,16 +67,16 @@ def _get_array(model, X: SeqT) -> Array2d: def _list_forward( - model: Model[List2d, List2d], Xs: List2d, is_train: bool + model: Model[SeqT, SeqT], Xs: List2d, is_train: bool ) -> Tuple[List2d, Callable]: - layer = model.layers[0] + layer: Model[Array2d, Array2d] = model.layers[0] pad = model.attrs["pad"] lengths = layer.ops.asarray1i([len(seq) for seq in Xs]) - Xf = layer.ops.flatten(Xs, pad=pad) # type: ignore + Xf = layer.ops.flatten(Xs, pad=pad) Yf, get_dXf = layer(Xf, is_train) def backprop(dYs: List2d) -> List2d: - dYf = layer.ops.flatten(dYs, pad=pad) # type: ignore + dYf = layer.ops.flatten(dYs, pad=pad) dXf = get_dXf(dYf) return layer.ops.unflatten(dXf, lengths, pad=pad) @@ -86,23 +84,23 @@ def backprop(dYs: List2d) -> List2d: def _ragged_forward( - model: Model[Ragged, Ragged], Xr: Ragged, is_train: bool + model: Model[SeqT, SeqT], Xr: Ragged, is_train: bool ) -> Tuple[Ragged, Callable]: - layer: Model[Array2d, ArrayXd] = model.layers[0] + layer: Model[Array2d, Array2d] = model.layers[0] Y, get_dX = layer(Xr.data, is_train) x_shape = Xr.dataXd.shape def backprop(dYr: Ragged) -> Ragged: return Ragged(get_dX(dYr.dataXd).reshape(x_shape), dYr.lengths) - + return Ragged(Y, Xr.lengths), backprop def _padded_forward( - model: Model[Padded, Padded], Xp: Padded, is_train: bool + model: Model[SeqT, SeqT], Xp: Padded, is_train: bool ) -> Tuple[Padded, Callable]: layer: Model[Array2d, Array2d] = model.layers[0] - X = model.ops.reshape2f( + X = model.ops.reshape2( Xp.data, Xp.data.shape[0] * Xp.data.shape[1], Xp.data.shape[2] ) Y2d, get_dX = layer(X, is_train) @@ -112,7 +110,7 @@ def _padded_forward( def backprop(dYp: Padded) -> Padded: assert isinstance(dYp, Padded) - dY = model.ops.reshape2f( + dY = model.ops.reshape2( dYp.data, dYp.data.shape[0] * 
dYp.data.shape[1], dYp.data.shape[2] ) dX2d = get_dX(dY) diff --git a/thinc/layers/with_flatten.py b/thinc/layers/with_flatten.py index 94246c5d2..2ae50f282 100644 --- a/thinc/layers/with_flatten.py +++ b/thinc/layers/with_flatten.py @@ -1,34 +1,34 @@ -from typing import Tuple, Callable, Sequence, Any, List, TypeVar +from typing import Tuple, Callable, Sequence, Any, cast, TypeVar, Optional, List from ..model import Model from ..config import registry -from ..types import Array2d, List2d +from ..types import ListXd ItemT = TypeVar("ItemT") InT = Sequence[Sequence[ItemT]] -OutT = List2d +OutT = TypeVar("OutT", bound=ListXd) @registry.layers("with_flatten.v1") -def with_flatten(layer: Model) -> Model[InT, OutT]: +def with_flatten(layer: Model[InT, InT]) -> Model[OutT, OutT]: return Model(f"with_flatten({layer.name})", forward, layers=[layer], init=init) def forward( - model: Model[InT, OutT], Xnest: InT, is_train: bool + model: Model[OutT, OutT], Xnest: OutT, is_train: bool ) -> Tuple[OutT, Callable]: - layer: Model[Sequence[Any], Array2d] = model.layers[0] - Xflat: Sequence[Any] = _flatten(Xnest) + layer: Model[InT, InT] = model.layers[0] + Xflat: Sequence[Sequence[Any]] = _flatten(Xnest) Yflat, backprop_layer = layer(Xflat, is_train) # Get the split points. We want n-1 splits for n items. arr = layer.ops.asarray1i([len(x) for x in Xnest[:-1]]) splits = arr.cumsum() Ynest = layer.ops.xp.split(Yflat, splits, axis=0) - def backprop(dYnest: OutT) -> InT: - # I think the input/output types might be wrong here? - dYflat = model.ops.flatten(dYnest) # type: ignore + def backprop(dYnest: OutT) -> OutT: + dYflat = model.ops.flatten(dYnest) # type: ignore[arg-type, var-annotated] + # type ignore necessary for older versions of Mypy/Pydantic dXflat = backprop_layer(dYflat) dXnest = layer.ops.xp.split(dXflat, splits, axis=-1) return dXnest @@ -36,14 +36,16 @@ def backprop(dYnest: OutT) -> InT: return Ynest, backprop -def _flatten(nested: InT) -> List[ItemT]: - flat: List[ItemT] = [] +def _flatten(nested: OutT) -> InT: + flat: List = [] for item in nested: flat.extend(item) - return flat + return cast(InT, flat) -def init(model, X=None, Y=None) -> None: +def init( + model: Model[OutT, OutT], X: Optional[OutT] = None, Y: Optional[OutT] = None +) -> None: model.layers[0].initialize( _flatten(X) if X is not None else None, model.layers[0].ops.xp.hstack(Y) if Y is not None else None, diff --git a/thinc/layers/with_list.py b/thinc/layers/with_list.py index 6b1a51fe7..9f86c24dc 100644 --- a/thinc/layers/with_list.py +++ b/thinc/layers/with_list.py @@ -1,11 +1,10 @@ from typing import Tuple, Callable, List, Optional, TypeVar, Union, cast -from ..types import Padded, Ragged, Floats2d, List2d +from ..types import Padded, Ragged, Array2d, List2d, Floats2d, Ints2d from ..model import Model from ..config import registry - -SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, List2d]) +SeqT = TypeVar("SeqT", Padded, Ragged, List2d, List[Floats2d], List[Ints2d]) @registry.layers("with_list.v1") @@ -23,14 +22,12 @@ def forward( model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool ) -> Tuple[SeqT, Callable]: layer: Model[List2d, List2d] = model.layers[0] - Y: Union[Padded, Ragged, List2d] if isinstance(Xseq, Padded): - Y, backprop = _padded_forward(layer, cast(Padded, Xseq), is_train) + return _padded_forward(layer, Xseq, is_train) elif isinstance(Xseq, Ragged): - Y, backprop = _ragged_forward(layer, cast(Ragged, Xseq), is_train) + return _ragged_forward(layer, Xseq, is_train) else: - Y, backprop = layer(cast(List2d, 
Xseq), is_train) - return cast(Tuple[SeqT, Callable], (Y, backprop)) + return cast(Tuple[SeqT, Callable], layer(cast(List2d, Xseq), is_train)) def init( @@ -51,7 +48,9 @@ def _get_list(model, seq): return seq -def _ragged_forward(layer, Xr, is_train): +def _ragged_forward( + layer: Model[List2d, List2d], Xr: Ragged, is_train: bool +) -> Tuple[Ragged, Callable]: # Assign these to locals, to keep code a bit shorter. unflatten = layer.ops.unflatten flatten = layer.ops.flatten @@ -62,12 +61,17 @@ def _ragged_forward(layer, Xr, is_train): Ys, get_dXs = layer(unflatten(Xr.data, Xr.lengths), is_train) def backprop(dYr: Ragged): - return Ragged(flatten(get_dXs(unflatten(dYr.data, dYr.lengths))), dYr.lengths) + return Ragged( + flatten(get_dXs(unflatten(dYr.data, dYr.lengths))), + dYr.lengths, + ) return Ragged(flatten(Ys), Xr.lengths), backprop -def _padded_forward(layer, Xp, is_train): +def _padded_forward( + layer: Model[List2d, List2d], Xp: Padded, is_train: bool +) -> Tuple[Padded, Callable]: # Assign these to locals, to keep code a bit shorter. padded2list = layer.ops.padded2list list2padded = layer.ops.list2padded @@ -80,4 +84,4 @@ def _padded_forward(layer, Xp, is_train): def backprop(dYp): return list2padded(get_dXs(padded2list(dYp))) - return list2padded(cast(List[Floats2d], Ys)), backprop + return list2padded(Ys), backprop diff --git a/thinc/layers/with_padded.py b/thinc/layers/with_padded.py index e2e33d5f3..379df1bef 100644 --- a/thinc/layers/with_padded.py +++ b/thinc/layers/with_padded.py @@ -1,6 +1,6 @@ -from typing import Tuple, Callable, Optional, TypeVar, Union, cast +from typing import Tuple, Callable, Optional, TypeVar, Union, cast, List -from ..types import Padded, Ragged, Array2d, Floats3d, Ints1d, Floats2d, List2d +from ..types import Padded, Ragged, Floats3d, Ints1d, List2d, Array2d from ..model import Model from ..config import registry from ..util import is_xp_array @@ -25,18 +25,23 @@ def forward( model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool ) -> Tuple[SeqT, Callable]: layer: Model[Padded, Padded] = model.layers[0] - Y: Union[Padded, Ragged, List2d, PaddedData] if isinstance(Xseq, Padded): - Y, backprop = layer(Xseq, is_train) + return cast(Tuple[SeqT, Callable], layer(Xseq, is_train)) elif isinstance(Xseq, Ragged): - Y, backprop = _ragged_forward(layer, cast(Ragged, Xseq), is_train) + return cast(Tuple[SeqT, Callable], _ragged_forward(layer, Xseq, is_train)) elif _is_padded_data(Xseq): - Y, backprop = _tuple_forward(layer, cast(PaddedData, Xseq), is_train) + return cast( + Tuple[SeqT, Callable], + _tuple_forward(layer, cast(PaddedData, Xseq), is_train), + ) elif is_xp_array(Xseq): - Y, backprop = _array_forward(layer, cast(Floats3d, Xseq), is_train) + return cast( + Tuple[SeqT, Callable], _array_forward(layer, cast(Floats3d, Xseq), is_train) + ) else: - Y, backprop = _list_forward(layer, cast(List2d, Xseq), is_train) - return cast(Tuple[SeqT, Callable], (Y, backprop)) + return cast( + Tuple[SeqT, Callable], _list_forward(layer, cast(List2d, Xseq), is_train) + ) def init( @@ -48,28 +53,31 @@ def init( ) -def _is_padded_data(seq): +def _is_padded_data(seq: SeqT) -> bool: return isinstance(seq, tuple) and len(seq) == 4 and all(map(is_xp_array, seq)) -def _get_padded(model, seq): +def _get_padded(model: Model, seq: SeqT) -> Padded: if isinstance(seq, Padded): return seq elif isinstance(seq, Ragged): return model.ops.list2padded(model.ops.unflatten(seq.data, seq.lengths)) elif _is_padded_data(seq): - return Padded(*seq) # type: ignore + return Padded(*seq) # type: 
ignore[misc] elif is_xp_array(seq): - size_at_t = model.ops.asarray1i([seq.shape[1]] * seq.shape[0]) - lengths = model.ops.asarray1i([seq.shape[0]] * seq.shape[1]) - indices = model.ops.xp.arange(seq.shape[1]) - return Padded(seq, size_at_t, lengths, indices) + floats3d_seq = cast(Floats3d, seq) + size_at_t = model.ops.asarray1i([floats3d_seq.shape[1]] * floats3d_seq.shape[0]) + lengths = model.ops.asarray1i([floats3d_seq.shape[0]] * floats3d_seq.shape[1]) + indices = model.ops.xp.arange(floats3d_seq.shape[1]) + return Padded(floats3d_seq, size_at_t, lengths, indices) else: assert isinstance(seq, list), seq return model.ops.list2padded(seq) -def _array_forward(layer, X, is_train): +def _array_forward( + layer: Model[Padded, Padded], X: Floats3d, is_train +) -> Tuple[Floats3d, Callable]: # Create bogus metadata for Padded. Xp = _get_padded(layer, X) Yp, get_dXp = layer(Xp, is_train) @@ -82,20 +90,24 @@ def backprop(dY: Floats3d) -> Floats3d: dXp = get_dXp(dYp) return dXp.data - return Yp.data, backprop + return cast(Floats3d, Yp.data), backprop -def _tuple_forward(layer, X, is_train: bool): +def _tuple_forward( + layer: Model[Padded, Padded], X: PaddedData, is_train: bool +) -> Tuple[PaddedData, Callable]: Yp, get_dXp = layer(Padded(*X), is_train) def backprop(dY): dXp = get_dXp(Padded(*dY)) return (dXp.data, dXp.size_at_t, dXp.lengths, dXp.indices) - return (Yp.data, Yp.size_at_t, Yp.lengths, Yp.indices), backprop + return (cast(Floats3d, Yp.data), Yp.size_at_t, Yp.lengths, Yp.indices), backprop -def _ragged_forward(layer, Xr, is_train): +def _ragged_forward( + layer: Model[Padded, Padded], Xr: Ragged, is_train: bool +) -> Tuple[Ragged, Callable]: # Assign these to locals, to keep code a bit shorter. list2padded = layer.ops.list2padded padded2list = layer.ops.padded2list @@ -109,22 +121,24 @@ def _ragged_forward(layer, Xr, is_train): def backprop(dYr: Ragged): flattened = flatten( - padded2list(get_dXp(list2padded(unflatten(dYr.data, dYr.lengths)))) + padded2list(get_dXp(list2padded(unflatten(dYr.data, dYr.lengths)))), ) - return Ragged(cast(Floats2d, flattened), dYr.lengths) + return Ragged(flattened, dYr.lengths) flattened = flatten(padded2list(Yp)) return Ragged(flattened, Xr.lengths), backprop -def _list_forward(layer, Xs, is_train): +def _list_forward( + layer: Model[Padded, Padded], Xs: List2d, is_train: bool +) -> Tuple[List2d, Callable]: # Assign these to locals, to keep code a bit shorter. 
list2padded = layer.ops.list2padded padded2list = layer.ops.padded2list - Yp, get_dXp = layer(list2padded(Xs), is_train) # type: ignore + Yp, get_dXp = layer(list2padded(Xs), is_train) def backprop(dYs): - return padded2list(get_dXp(list2padded(dYs))) # type: ignore + return padded2list(get_dXp(list2padded(dYs))) return padded2list(Yp), backprop diff --git a/thinc/layers/with_ragged.py b/thinc/layers/with_ragged.py index db01832f2..005c69048 100644 --- a/thinc/layers/with_ragged.py +++ b/thinc/layers/with_ragged.py @@ -1,12 +1,11 @@ -from typing import Tuple, Callable, Optional, TypeVar, Union, cast +from typing import Tuple, Callable, Optional, TypeVar, cast, List, Union -from ..types import Padded, Ragged, Ints1d, Array2d, List2d +from ..types import Padded, Ragged, Array2d, ListXd, List2d, Ints1d from ..model import Model from ..config import registry - RaggedData = Tuple[Array2d, Ints1d] -SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, List2d, RaggedData]) +SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, ListXd, RaggedData]) @registry.layers("with_ragged.v1") @@ -18,20 +17,25 @@ def forward( model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool ) -> Tuple[SeqT, Callable]: layer: Model[Ragged, Ragged] = model.layers[0] - Y: Union[Padded, Ragged, List2d, RaggedData] if isinstance(Xseq, Ragged): - Y, backprop = layer(Xseq, is_train) + return cast(Tuple[SeqT, Callable], layer(Xseq, is_train)) elif isinstance(Xseq, Padded): - Y, backprop = _padded_forward(layer, cast(Padded, Xseq), is_train) + return cast(Tuple[SeqT, Callable], _padded_forward(layer, Xseq, is_train)) elif _is_ragged_data(Xseq): - Y, backprop = _tuple_forward(layer, cast(RaggedData, Xseq), is_train) + return cast( + Tuple[SeqT, Callable], + _tuple_forward(layer, cast(RaggedData, Xseq), is_train), + ) else: - Y, backprop = _list_forward(layer, cast(List2d, Xseq), is_train) - return cast(Tuple[SeqT, Callable], (Y, backprop)) + return cast( + Tuple[SeqT, Callable], _list_forward(layer, cast(List, Xseq), is_train) + ) def init( - model: Model[SeqT, SeqT], X: Optional[SeqT] = None, Y: Optional[SeqT] = None + model: Model[SeqT, SeqT], + X: Optional[SeqT] = None, + Y: Optional[SeqT] = None, ) -> None: model.layers[0].initialize( X=_get_ragged(model, X) if X is not None else None, @@ -39,26 +43,29 @@ def init( ) - def _is_ragged_data(seq): return isinstance(seq, tuple) and len(seq) == 2 -def _get_ragged(model, seq): +def _get_ragged(model: Model[SeqT, SeqT], seq: SeqT) -> Ragged: if isinstance(seq, Ragged): return seq elif isinstance(seq, Padded): lists = model.ops.padded2list(seq) lengths = model.ops.asarray1i([len(x) for x in lists]) + k = model.ops.flatten(lists) return Ragged(model.ops.flatten(lists), lengths) elif _is_ragged_data(seq): - return Ragged(*seq) + return Ragged(*seq) # type: ignore[misc] else: - lengths = model.ops.asarray1i([len(x) for x in seq]) - return Ragged(model.ops.flatten(seq), lengths) + list2d_seq = cast(List2d, seq) + lengths = model.ops.asarray1i([len(x) for x in list2d_seq]) + return Ragged(model.ops.flatten(list2d_seq), lengths) -def _tuple_forward(layer: Model[Ragged, Ragged], X: RaggedData, is_train: bool): +def _tuple_forward( + layer: Model[Ragged, Ragged], X: RaggedData, is_train: bool +) -> Tuple[RaggedData, Callable]: Yr, get_dXr = layer(Ragged(*X), is_train) def backprop(dY: RaggedData) -> RaggedData: @@ -68,7 +75,9 @@ def backprop(dY: RaggedData) -> RaggedData: return (Yr.data, Yr.lengths), backprop -def _padded_forward(layer, Xp, is_train): +def _padded_forward( + layer: Model[Ragged, 
diff --git a/thinc/layers/with_ragged.py b/thinc/layers/with_ragged.py
index db01832f2..005c69048 100644
--- a/thinc/layers/with_ragged.py
+++ b/thinc/layers/with_ragged.py
@@ -1,12 +1,11 @@
-from typing import Tuple, Callable, Optional, TypeVar, Union, cast
+from typing import Tuple, Callable, Optional, TypeVar, cast, List, Union
 
-from ..types import Padded, Ragged, Ints1d, Array2d, List2d
+from ..types import Padded, Ragged, Array2d, ListXd, List2d, Ints1d
 from ..model import Model
 from ..config import registry
 
-
 RaggedData = Tuple[Array2d, Ints1d]
-SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, List2d, RaggedData])
+SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, ListXd, RaggedData])
 
 
 @registry.layers("with_ragged.v1")
@@ -18,20 +17,25 @@ def forward(
     model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool
 ) -> Tuple[SeqT, Callable]:
     layer: Model[Ragged, Ragged] = model.layers[0]
-    Y: Union[Padded, Ragged, List2d, RaggedData]
     if isinstance(Xseq, Ragged):
-        Y, backprop = layer(Xseq, is_train)
+        return cast(Tuple[SeqT, Callable], layer(Xseq, is_train))
     elif isinstance(Xseq, Padded):
-        Y, backprop = _padded_forward(layer, cast(Padded, Xseq), is_train)
+        return cast(Tuple[SeqT, Callable], _padded_forward(layer, Xseq, is_train))
     elif _is_ragged_data(Xseq):
-        Y, backprop = _tuple_forward(layer, cast(RaggedData, Xseq), is_train)
+        return cast(
+            Tuple[SeqT, Callable],
+            _tuple_forward(layer, cast(RaggedData, Xseq), is_train),
+        )
     else:
-        Y, backprop = _list_forward(layer, cast(List2d, Xseq), is_train)
-    return cast(Tuple[SeqT, Callable], (Y, backprop))
+        return cast(
+            Tuple[SeqT, Callable], _list_forward(layer, cast(List, Xseq), is_train)
+        )
 
 
 def init(
-    model: Model[SeqT, SeqT], X: Optional[SeqT] = None, Y: Optional[SeqT] = None
+    model: Model[SeqT, SeqT],
+    X: Optional[SeqT] = None,
+    Y: Optional[SeqT] = None,
 ) -> None:
     model.layers[0].initialize(
         X=_get_ragged(model, X) if X is not None else None,
@@ -39,26 +43,29 @@ def init(
     )
 
-
 def _is_ragged_data(seq):
     return isinstance(seq, tuple) and len(seq) == 2
 
 
-def _get_ragged(model, seq):
+def _get_ragged(model: Model[SeqT, SeqT], seq: SeqT) -> Ragged:
     if isinstance(seq, Ragged):
         return seq
     elif isinstance(seq, Padded):
         lists = model.ops.padded2list(seq)
         lengths = model.ops.asarray1i([len(x) for x in lists])
+        k = model.ops.flatten(lists)
         return Ragged(model.ops.flatten(lists), lengths)
     elif _is_ragged_data(seq):
-        return Ragged(*seq)
+        return Ragged(*seq)  # type: ignore[misc]
     else:
-        lengths = model.ops.asarray1i([len(x) for x in seq])
-        return Ragged(model.ops.flatten(seq), lengths)
+        list2d_seq = cast(List2d, seq)
+        lengths = model.ops.asarray1i([len(x) for x in list2d_seq])
+        return Ragged(model.ops.flatten(list2d_seq), lengths)
 
 
-def _tuple_forward(layer: Model[Ragged, Ragged], X: RaggedData, is_train: bool):
+def _tuple_forward(
+    layer: Model[Ragged, Ragged], X: RaggedData, is_train: bool
+) -> Tuple[RaggedData, Callable]:
     Yr, get_dXr = layer(Ragged(*X), is_train)
 
     def backprop(dY: RaggedData) -> RaggedData:
@@ -68,7 +75,9 @@ def backprop(dY: RaggedData) -> RaggedData:
     return (Yr.data, Yr.lengths), backprop
 
 
-def _padded_forward(layer, Xp, is_train):
+def _padded_forward(
+    layer: Model[Ragged, Ragged], Xp: Padded, is_train: bool
+) -> Tuple[Padded, Callable]:
     # Assign these to locals, to keep code a bit shorter.
     list2padded = layer.ops.list2padded
     padded2list = layer.ops.padded2list
@@ -86,12 +95,18 @@ def _padded_forward(layer, Xp, is_train):
 
     def backprop(dYp: Padded):
         flattened = flatten(padded2list(dYp))
-        return list2padded(unflatten(get_dXr(Ragged(flattened, lengths)).data, lengths))
+        dXr = get_dXr(Ragged(flattened, lengths))
+        return list2padded(unflatten(dXr.data, lengths))
 
-    return list2padded(unflatten(Yr.data, Yr.lengths)), backprop
+    return (
+        list2padded(unflatten(Yr.data, Yr.lengths)),
+        backprop,
+    )
 
 
-def _list_forward(layer, Xs, is_train: bool):
+def _list_forward(
+    layer: Model[Ragged, Ragged], Xs: List, is_train: bool
+) -> Tuple[List, Callable]:
     # Assign these to locals, to keep code a bit shorter.
     flatten = layer.ops.flatten
     unflatten = layer.ops.unflatten
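For orientation (not part of the patch), a minimal sketch of the `flatten`/`unflatten` pair that the `with_ragged` helpers above build on, assuming `NumpyOps` and arbitrary shapes:

```python
# Illustration only: concatenate sequences and split them again by length.
import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
seqs = [numpy.ones((3, 2), dtype="f"), numpy.ones((1, 2), dtype="f")]
flat = ops.flatten(seqs)                             # Floats2d with shape (4, 2)
lengths = ops.asarray1i([len(seq) for seq in seqs])
assert [a.shape for a in ops.unflatten(flat, lengths)] == [(3, 2), (1, 2)]
```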
diff --git a/thinc/layers/with_reshape.py b/thinc/layers/with_reshape.py
index bc7dc521b..5bd3e9025 100644
--- a/thinc/layers/with_reshape.py
+++ b/thinc/layers/with_reshape.py
@@ -1,15 +1,16 @@
-from typing import Tuple, Callable, Optional, cast
+from typing import Tuple, Callable, Optional, cast, TypeVar, List
 
 from ..model import Model
 from ..config import registry
-from ..types import Array3d, Array2d, Floats3d
+from ..types import Array3d, Array2d
 
-InT = Array3d
+InT = TypeVar("InT", bound=Array3d)
+OutT = TypeVar("OutT", bound=Array2d)
 
 
 @registry.layers("with_reshape.v1")
-def with_reshape(layer: Model[Array2d, Array2d]) -> Model[InT, InT]:
+def with_reshape(layer: Model[OutT, OutT]) -> Model[InT, InT]:
     """Reshape data on the way into and out from a layer."""
     return Model(
         f"with_reshape({layer.name})",
@@ -26,16 +27,15 @@ def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callab
     final_shape = list(initial_shape[:-1]) + [layer.get_dim("nO")]
     nB = X.shape[0]
     nT = X.shape[1]
-    X2d = cast(InT, model.ops.reshape(X, (-1, X.shape[2])))
+    X2d = model.ops.reshape(X, (-1, X.shape[2]))
     Y2d, Y2d_backprop = layer(X2d, is_train=is_train)
-    Y = model.ops.reshape3f(Y2d, *final_shape)
+    Y = model.ops.reshape3(Y2d, *final_shape)
 
     def backprop(dY: InT) -> InT:
-        dY_floats = model.ops.asarray3f(cast(Floats3d, dY))
-        reshaped = model.ops.reshape2f(dY_floats, nB * nT, -1)
-        return Y2d_backprop(model.ops.reshape3f(reshaped, *initial_shape))
+        reshaped = model.ops.reshape2(dY, nB * nT, -1)
+        return Y2d_backprop(model.ops.reshape3(reshaped, *initial_shape))
 
-    return Y, backprop
+    return cast(InT, Y), backprop
 
 
 def init(
diff --git a/thinc/loss.py b/thinc/loss.py
index 3632e93cf..990b30df1 100644
--- a/thinc/loss.py
+++ b/thinc/loss.py
@@ -1,4 +1,4 @@
-from typing import Tuple, List, cast, TypeVar, Generic, Any, Union, Optional
+from typing import Tuple, Sequence, cast, TypeVar, Generic, Any, Union, Optional, List
 from typing import Dict
 
 from .types import Floats2d, Ints1d
@@ -11,7 +11,7 @@
 GuessT = TypeVar("GuessT")
 TruthT = TypeVar("TruthT")
 IntsOrFloats = Union[Ints1d, Floats2d]
-IntsOrFloatsOrStrs = Union[Ints1d, Floats2d, List[int], List[str]]
+IntsOrFloatsOrStrs = Union[Ints1d, Floats2d, Sequence[int], Sequence[str]]
 
 
 class Loss(Generic[GuessT, TruthT, GradT, LossT]):  # pragma: no cover
@@ -35,7 +35,7 @@ def get_loss(self, guesses: GuessT, truths: TruthT) -> LossT:
 
 
 class CategoricalCrossentropy(Loss):
-    names: Optional[List[str]]
+    names: Optional[Sequence[str]]
    missing_value: Optional[Union[str, int]]
     _name_to_i: Dict[str, int]
 
@@ -43,7 +43,7 @@ def __init__(
         self,
         *,
         normalize: bool = True,
-        names: Optional[List[str]] = None,
+        names: Optional[Sequence[str]] = None,
         missing_value: Optional[Union[str, int]] = None,
         neg_prefix: Optional[str] = None,
         label_smoothing: float = 0.0,
@@ -160,7 +160,7 @@ def _get_loss_from_grad(self, d_truth: Floats2d) -> float:
 def configure_CategoricalCrossentropy_v1(
     *,
     normalize: bool = True,
-    names: Optional[List[str]] = None,
+    names: Optional[Sequence[str]] = None,
     missing_value: Optional[Union[str, int]] = None,
 ) -> CategoricalCrossentropy:
     return CategoricalCrossentropy(
@@ -172,7 +172,7 @@ def configure_CategoricalCrossentropy_v1(
 def configure_CategoricalCrossentropy_v2(
     *,
     normalize: bool = True,
-    names: Optional[List[str]] = None,
+    names: Optional[Sequence[str]] = None,
     missing_value: Optional[Union[str, int]] = None,
     neg_prefix: Optional[str] = None,
 ) -> CategoricalCrossentropy:
@@ -188,7 +188,7 @@ def configure_CategoricalCrossentropy_v2(
 def configure_CategoricalCrossentropy_v3(
     *,
     normalize: bool = True,
-    names: Optional[List[str]] = None,
+    names: Optional[Sequence[str]] = None,
     missing_value: Optional[Union[str, int]] = None,
     neg_prefix: Optional[str] = None,
     label_smoothing: float = 0.0,
@@ -207,7 +207,7 @@ def __init__(
         self,
         *,
         normalize: bool = True,
-        names: Optional[List[str]] = None,
+        names: Optional[Sequence[str]] = None,
         missing_value: Optional[Union[str, int]] = None,
         neg_prefix: Optional[str] = None,
         label_smoothing: float = 0.0,
@@ -222,14 +222,14 @@ def __init__(
         self.normalize = normalize
 
     def __call__(
-        self, guesses: List[Floats2d], truths: List[IntsOrFloatsOrStrs]
+        self, guesses: Sequence[Floats2d], truths: Sequence[IntsOrFloatsOrStrs]
     ) -> Tuple[List[Floats2d], float]:
         grads = self.get_grad(guesses, truths)
         loss = self._get_loss_from_grad(grads)
         return grads, loss
 
     def get_grad(
-        self, guesses: List[Floats2d], truths: List[IntsOrFloatsOrStrs]
+        self, guesses: Sequence[Floats2d], truths: Sequence[IntsOrFloatsOrStrs]
     ) -> List[Floats2d]:
         err = "Cannot calculate SequenceCategoricalCrossentropy loss: guesses and truths must be same length"
         if len(guesses) != len(truths):  # pragma: no cover
@@ -244,11 +244,11 @@ def get_grad(
         return d_scores
 
     def get_loss(
-        self, guesses: List[Floats2d], truths: List[IntsOrFloatsOrStrs]
+        self, guesses: Sequence[Floats2d], truths: Sequence[IntsOrFloatsOrStrs]
     ) -> float:
         return self._get_loss_from_grad(self.get_grad(guesses, truths))
 
-    def _get_loss_from_grad(self, grads: List[Floats2d]) -> float:
+    def _get_loss_from_grad(self, grads: Sequence[Floats2d]) -> float:
         loss = 0.0
         for grad in grads:
             loss += self.cc._get_loss_from_grad(grad)
@@ -257,7 +257,7 @@ def _get_loss_from_grad(self, grads: List[Floats2d]) -> float:
 
 @registry.losses("SequenceCategoricalCrossentropy.v1")
 def configure_SequenceCategoricalCrossentropy_v1(
-    *, normalize: bool = True, names: Optional[List[str]] = None
+    *, normalize: bool = True, names: Optional[Sequence[str]] = None
 ) -> SequenceCategoricalCrossentropy:
     return SequenceCategoricalCrossentropy(normalize=normalize, names=names)
 
@@ -266,7 +266,7 @@ def configure_SequenceCategoricalCrossentropy_v1(
 def configure_SequenceCategoricalCrossentropy_v2(
     *,
     normalize: bool = True,
-    names: Optional[List[str]] = None,
+    names: Optional[Sequence[str]] = None,
     neg_prefix: Optional[str] = None,
 ) -> SequenceCategoricalCrossentropy:
     return SequenceCategoricalCrossentropy(
@@ -278,7 +278,7 @@ def configure_SequenceCategoricalCrossentropy_v2(
 def configure_SequenceCategoricalCrossentropy_v3(
     *,
     normalize: bool = True,
-    names: Optional[List[str]] = None,
+    names: Optional[Sequence[str]] = None,
     missing_value: Optional[Union[str, int]] = None,
     neg_prefix: Optional[str] = None,
     label_smoothing: float = 0.0,
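To illustrate the `Sequence[str]` loosening above (not part of the patch): any sequence of label names, such as a tuple, should now satisfy the annotations. The labels and scores below are made up.

```python
# Illustration only: names passed as a tuple instead of a list.
import numpy
from thinc.api import CategoricalCrossentropy

loss = CategoricalCrossentropy(names=("cat", "dog"))
guesses = numpy.asarray([[0.9, 0.1], [0.3, 0.7]], dtype="f")
d_scores = loss.get_grad(guesses, ["cat", "dog"])
```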
diff --git a/thinc/types.py b/thinc/types.py
index f26789399..74498d159 100644
--- a/thinc/types.py
+++ b/thinc/types.py
@@ -42,7 +42,7 @@
 List2d = Union[List["Floats2d"], List["Ints2d"]]
 List3d = Union[List["Floats3d"], List["Ints3d"]]
 List4d = Union[List["Floats4d"], List["Ints4d"]]
-ListXd = Union[List["FloatsXd"], List["IntsXd"]]
+ListXd = Union[List1d, List2d, List3d, List4d]
 
 ArrayT = TypeVar("ArrayT")
 SelfT = TypeVar("SelfT")
@@ -712,7 +712,31 @@ def __get_validators__(cls):
         yield lambda v: validate_array(v, ndim=4, dtype="i")
 
     def __iter__(self) -> Iterator[Ints3d]: ...
-    # def __getitem__(self, key: int) -> Ints3d: ...
+
+    @overload
+    def __getitem__(self, key: _4_KeyScalar) -> int: ...
+    @overload
+    def __getitem__(self, key: _4_Key1d) -> Ints1d: ...
+    @overload
+    def __getitem__(self, key: _4_Key2d) -> Ints2d: ...
+    @overload
+    def __getitem__(self, key: _4_Key3d) -> Ints3d: ...
+    @overload
+    def __getitem__(self, key: _4_Key4d) -> "Ints4d": ...
+    def __getitem__(self, key: _4_AllKeys) -> _I4_AllReturns: ...
+
+    @overload
+    def __setitem__(self, key: _4_KeyScalar, value: int) -> None: ...
+    @overload
+    def __setitem__(self, key: _4_Key1d, value: Ints1d) -> None: ...
+    @overload
+    def __setitem__(self, key: _4_Key2d, value: Ints2d) -> None: ...
+    @overload
+    def __setitem__(self, key: _4_Key3d, value: Ints3d) -> None: ...
+    @overload
+    def __setitem__(self, key: _4_Key4d, value: "Ints4d") -> None: ...
+
+    def __setitem__(self, key: _4_AllKeys, value: _I4_AllReturns) -> None: ...
 
     @overload
     def sum(self, *, keepdims: Tru, axis: _4_AllAx = None, out: Optional["Ints4d"] = None) -> "Ints4d": ...
@@ -782,7 +806,7 @@ class Padded:
     and the indices indicates the original ordering.
     """
 
-    data: Floats3d
+    data: Array3d
     size_at_t: Ints1d
     lengths: Ints1d
     indices: Ints1d
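As a small illustration of the `ListXd` redefinition above (not part of the patch): `ListXd` now unions the fixed-dimension list types, so lists of 1d or 2d arrays both satisfy it.

```python
# Illustration only: both annotations are valid ListXd values (shapes arbitrary).
import numpy
from thinc.types import ListXd

floats_2d: ListXd = [numpy.zeros((2, 3), dtype="f")]  # a List2d
ints_1d: ListXd = [numpy.zeros((5,), dtype="i")]      # a List1d
```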
diff --git a/thinc/util.py b/thinc/util.py
index 9dfc755cc..dacfb09c2 100644
--- a/thinc/util.py
+++ b/thinc/util.py
@@ -1,5 +1,5 @@
 from typing import Any, Union, Sequence, cast, Dict, Optional, Callable, TypeVar
-from typing import List, Tuple
+from typing import List, Mapping, Tuple
 import numpy
 from packaging.version import Version
 import random
@@ -490,7 +490,11 @@ def partial(
 
 class DataValidationError(ValueError):
     def __init__(
-        self, name: str, X: Any, Y: Any, errors: List[Dict[str, Any]] = []
+        self,
+        name: str,
+        X: Any,
+        Y: Any,
+        errors: Union[Sequence[Mapping[str, Any]], List[Dict[str, Any]]] = [],
     ) -> None:
         """Custom error for validating inputs / outputs at runtime."""
         message = f"Data validation error in '{name}'"
@@ -564,7 +568,7 @@ def data_validation(validation):
 
 
 @contextlib.contextmanager
-def use_nvtx_range(message: int, id_color: int = -1):
+def use_nvtx_range(message: str, id_color: int = -1):
     """Context manager to register the executed code as an NVTX range.
     The ranges can be used as markers in CUDA profiling."""
     if has_cupy:
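A hedged usage sketch for the corrected `use_nvtx_range` annotation (not part of the patch); the range is only meaningful under CUDA profiling, so the call is gated on `has_cupy` here.

```python
# Illustration only: register a profiling range around some code.
from thinc.util import has_cupy, use_nvtx_range

if has_cupy:
    with use_nvtx_range("forward pass"):
        pass  # run the code to profile here
```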
diff --git a/website/docs/api-layers.md b/website/docs/api-layers.md
index b4c51e4b8..1c43a9d7a 100644
--- a/website/docs/api-layers.md
+++ b/website/docs/api-layers.md
@@ -48,8 +48,10 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/cauchysimilarity.py
 
-- **Input:** ArrayXd
-- **Output:** ArrayXd
+- **Input:** ArrayXd / Sequence[ArrayXd] /
+  Ragged / Padded
+- **Output:** ArrayXd / Sequence[ArrayXd]
+  / Ragged / Padded
 - **Attrs:** `dropout_rate` float
 
@@ -71,10 +73,10 @@ for node in model.walk():
     node.attrs["dropout_rate"] = 0.5
 ```
 
-| Argument       | Type                             | Description |
-| -------------- | -------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
-| `dropout_rate` | float                            | The probability of zeroing the activations (default: 0). Higher dropout rates mean more distortion. Values around `0.2` are often good.  |
-| **RETURNS**    | Model[ArrayXd, ArrayXd]          | The created dropout layer. |
+| Argument       | Type                 | Description |
+| -------------- | -------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
+| `dropout_rate` | float                | The probability of zeroing the activations (default: 0). Higher dropout rates mean more distortion. Values around `0.2` are often good.  |
+| **RETURNS**    | Model[T, T]          | The created dropout layer. |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/dropout.py
@@ -84,8 +86,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/dropout.py
 
-- **Input:** Ints1d /
-  Ints2d
+- **Input:** Union[Ints1d, Ints2d]
 - **Output:** Floats2d
 - **Parameters:** E
 - **Attrs:** `column` int, `dropout_rate` float
 
@@ -114,8 +115,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/embed.py
 
-- **Input:** Ints1d /
-  Ints2d
+- **Input:** Union[Ints1d, Ints2d] /
 - **Output:** Floats2d
 - **Parameters:** E
 - **Attrs:** `seed` Optional[int], `column` int,
@@ -238,8 +238,8 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/sigmoid.py
 
-- **Input:** Floats2d
-- **Output:** Floats2d
+- **Input:** FloatsXd
+- **Output:** FloatsXd
 
@@ -252,7 +252,7 @@ element of the output vectors will be between `0` and `1`.
 | **RETURNS** | Model[Floats2d, Floats2d] | The created `sigmoid_activation` layer. |
 
 ```python
-https://github.com/explosion/thinc/blob/master/thinc/layers/sigmoid_logistic.py
+https://github.com/explosion/thinc/blob/master/thinc/layers/sigmoid_activation.py
 ```
 
 ### LSTM and BiLSTM {#lstm tag="function"}
@@ -779,34 +779,6 @@ of the lengths should equal the length of the keys and values array.
 https://github.com/explosion/thinc/blob/master/thinc/layers/sparselinear.pyx
 ```
 
-### StaticVectors {#staticvectors tag="function"}
-
-
-
-- **Input:** Ints2d
-- **Output:** Floats2d
-- **Attrs:** `column` int, `vectors` Optional[Floats2d],
-  `dropout_rate` float
-
-
-
-
-
-| Argument       | Type                    | Description |
-| -------------- | -------------------------------- | --------------------------------------------------- |
-| `nO`           | Optional[int]           | The size of the output vectors. |
-| `vectors`      | Optional[Floats2d]      | The vectors. |
-| _keyword-only_ |                         | |
-| `column`       | int                     | The column of values to slice for the indices. |
-| `dropout`      | Optional[float]         | Dropout rate to avoid overfitting (default `None`). |
-| **RETURNS**    | Model[Ints2d, Floats2d] | The created embedding layer. |
-
-```python
-https://github.com/explosion/thinc/blob/master/thinc/layers/staticvectors.py
-```
-
----
-
 ## Reduction operations {#reduction-ops}
 
 ### reduce_first {#reduce_first tag="function"}
@@ -814,7 +786,7 @@
 
 - **Input:** Ragged
-- **Output:** Floats2d
+- **Output:** ArrayXd
 
@@ -823,9 +795,9 @@ item of each sequence. This is most useful after multi-head attention layers,
 which can learn to assign a good feature representation for the sequence to one
 of its elements.
 
-| Argument    | Type                             | Description                |
-| ----------- | -------------------------------- | -------------------------- |
-| **RETURNS** | Model[Ragged, Floats2d]          | The created pooling layer. |
+| Argument    | Type                            | Description                |
+| ----------- | ------------------------------- | -------------------------- |
+| **RETURNS** | Model[Ragged, ArrayXd]          | The created pooling layer. |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/reduce_first.py
 ```
@@ -841,13 +813,13 @@ representation for the sequence to its final element.
 
 - **Input:** Ragged
-- **Output:** Floats2d
+- **Output:** ArrayXd
 
 
-| Argument    | Type                             | Description                |
-| ----------- | -------------------------------- | -------------------------- |
-| **RETURNS** | Model[Ragged, Floats2d]          | The created pooling layer. |
+| Argument    | Type                            | Description                |
+| ----------- | ------------------------------- | -------------------------- |
+| **RETURNS** | Model[Ragged, ArrayXd]          | The created pooling layer. |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/reduce_last.py
@@ -930,10 +902,10 @@ to `>>` allows you to write `Relu(512) >> Softmax()` instead of
 Compose two or more models `f`, `g`, etc, such that their outputs are added,
 i.e. `add(f, g)(x)` computes `f(x) + g(x)`.
 
-| Argument    | Type                             | Description            |
-| ----------- | -------------------------------- | ---------------------- |
-| `*layers`   | Model[ArrayXd, ArrayXd]          | The models to compose. |
-| **RETURNS** | Model[ArrayXd, ArrayXd]          | The composed model.    |
+| Argument    | Type                         | Description            |
+| ----------- | ---------------------------- | ---------------------- |
+| `*layers`   | Model[Any, ArrayXd]          | The models to compose. |
+| **RETURNS** | Model[Any, ArrayXd]          | The composed model.    |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/add.py
@@ -955,13 +927,15 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/bidirectional.py
 ### chain {#chain tag="function"}
 
-Compose two models `f` and `g` such that they become layers of a single
-feed-forward model that computes `g(f(x))`.
+Compose two or more models such that they become layers of a single feed-forward
+model, e.g. `chain(f, g)` computes `g(f(x))`.
 
-| Argument    | Type                             | Description                      |
-| ----------- | -------------------------------- | -------------------------------- |
-| `*layers`   | Model[ArrayXd, ArrayXd]          | The models to compose.           |
-| **RETURNS** | Model[ArrayXd, ArrayXd]          | The composed feed-forward model. |
+| Argument    | Type           | Description                       |
+| ----------- | -------------- | --------------------------------- |
+| `layer1 `   | Model          | The first model to compose.       |
+| `layer2`    | Model          | The second model to compose.      |
+| `*layers`   | Model          | Any additional models to compose. |
+| **RETURNS** | Model          | The composed feed-forward model.  |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/chain.py
 ```
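A minimal usage sketch for `chain` as documented above (not part of the patch; layer sizes are arbitrary):

```python
# Illustration only: chain(f, g) computes g(f(x)).
import numpy
from thinc.api import chain, Relu, Softmax

model = chain(Relu(nO=8, nI=4), Softmax(nO=2))
model.initialize(X=numpy.zeros((2, 4), dtype="f"))
Y = model.predict(numpy.zeros((2, 4), dtype="f"))  # shape (2, 2)
```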
@@ -972,11 +946,11 @@
 ### clone {#clone tag="function"}
 
 Construct `n` copies of a layer, with distinct weights. For example,
 `clone(f, 3)(x)` computes `f(f'(f''(x)))`.
 
-| Argument    | Type                             | Description                        |
-| ----------- | -------------------------------- | ---------------------------------- |
-| `orig`      | Model[ArrayXd, ArrayXd]          | The layer to copy.                 |
-| `n`         | int                              | The number of copies to construct. |
-| **RETURNS** | Model[ArrayXd, ArrayXd]          | The composed model.                |
+| Argument    | Type           | Description                                       |
+| ----------- | -------------- | ------------------------------------------------ |
+| `orig`      | Model          | The layer to copy.                                |
+| `n`         | int            | The number of copies to construct.                |
+| **RETURNS** | Model          | A composite model containing two or more copies. |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/clone.py
@@ -987,10 +961,10 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/clone.py
 Compose two or more models `f`, `g`, etc, such that their outputs are
 concatenated, i.e. `concatenate(f, g)(x)` computes `hstack(f(x), g(x))`.
 
-| Argument    | Type                             | Description            |
-| ----------- | -------------------------------- | ---------------------- |
-| `*layers`   | Model[ArrayXd, ArrayXd]          | The models to compose. |
-| **RETURNS** | Model[ArrayXd, ArrayXd]          | The composed model.    |
+| Argument    | Type                | Description            |
+| ----------- | ------------------- | ---------------------- |
+| `*layers`   | Model, ...          | The models to compose. |
+| **RETURNS** | Model               | The composed model.    |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/concatenate.py
@@ -1025,10 +999,10 @@ and a window of surrounding vectors. This is one step in a convolution. If the
 concatenating three contextual vectors from the left, and three from the right,
 to each input vector. In general, `nO` equals `nI * (2 * window_size + 1)`.
 
-| Argument      | Type                     | Description                                                                     |
-| ------------- | ------------------------ | ------------------------------------------------------------------------------ |
-| `window_size` | int                      | The window size (default 1) that determines the number of surrounding vectors. |
-| **RETURNS**   | Model[InT, InT]          | The created layer for adding context to vectors.                                |
+| Argument      | Type                 | Description                                                                     |
+| ------------- | -------------------- | ------------------------------------------------------------------------------ |
+| `window_size` | int                  | The window size (default 1) that determines the number of surrounding vectors. |
+| **RETURNS**   | Model[T, T]          | The created layer for adding context to vectors.                                |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/expand_window.py
@@ -1038,10 +1012,10 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/expand_window.py
 Transform a sequences of layers into a null operation.
 
-| Argument    | Type                             | Description            |
-| ----------- | -------------------------------- | ---------------------- |
-| `*layers`   | Model[ArrayXd, ArrayXd]          | The models to compose. |
-| **RETURNS** | Model[ArrayXd, ArrayXd]          | The composed model.    |
+| Argument    | Type           | Description            |
+| ----------- | -------------- | ---------------------- |
+| `*layers`   | Model          | The models to compose. |
+| **RETURNS** | Model          | The composed model.    |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/noop.py
 ```
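A minimal usage sketch for `clone` as documented above (not part of the patch; sizes arbitrary):

```python
# Illustration only: three copies of the same layer, each with its own weights.
import numpy
from thinc.api import clone, Relu

model = clone(Relu(nO=8, nI=8), 3)
model.initialize(X=numpy.zeros((2, 8), dtype="f"))
```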
@@ -1051,8 +1025,14 @@
 ```
 
-- **Input:** List[FloatsXd], Ragged, Padded, FloatsXd
-- **Output:** List[FloatsXd], Ragged, Padded, FloatsXd
+- **Input:** List[FloatsXd] / Ragged /
+  Padded / FloatsXd
+  Floats1d Floats2d
+  Floats3d Floats4d
+- **Output:** List[FloatsXd] / Ragged /
+  Padded / FloatsXd
+  Floats1d Floats2d
+  Floats3d Floats4d
 
@@ -1078,10 +1058,10 @@ input to a downstream layer. On the backward pass the loss from each child is
 added together, so when using custom datatypes they should define an addition
 operator.
 
-| Argument    | Type                             | Description                      |
-| ----------- | -------------------------------- | -------------------------------- |
-| `*layers`   | Model[ArrayXd, ArrayXd]          | The models to compose.           |
-| **RETURNS** | Model[ArrayXd, ArrayXd]          | The composed feed-forward model. |
+| Argument    | Type                          | Description                      |
+| ----------- | ----------------------------- | -------------------------------- |
+| `*layers`   | Model[Any, T] ...             | The models to compose.           |
+| **RETURNS** | Model[Any, Tuple[T]]          | The composed feed-forward model. |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/tuplify.py
@@ -1112,12 +1092,12 @@ minibatch. The `uniqued` wrapper is useful for word inputs, because common words
 are seen often, but we may want to compute complicated features for the words,
 using e.g. character LSTM.
 
-| Argument       | Type                              | Description                  |
-| -------------- | --------------------------------- | ---------------------------- |
-| `layer`        | Model                             | The layer.                   |
-| _keyword-only_ |                                   |                              |
-| `column`       | int                               | The column. Defaults to `0`. |
-| **RETURNS**    | Model[ArrayXd, FloatsXd]          | The composed model.          |
+| Argument       | Type                             | Description                  |
+| -------------- | -------------------------------- | ---------------------------- |
+| `layer`        | Model                            | The layer.                   |
+| _keyword-only_ |                                  |                              |
+| `column`       | int                              | The column. Defaults to `0`. |
+| **RETURNS**    | Model[Ints2d, Floats2d]          | The composed model.          |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/uniqued.py
@@ -1153,7 +1133,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/array_getitem.py
 
-- **Input:** List[Array2d]
+- **Input:** List2d
 - **Output:** Array2d
 
@@ -1162,9 +1142,9 @@ Transform sequences to ragged arrays if necessary. If sequences are already
 ragged, do nothing. A ragged array is a tuple `(data, lengths)`, where `data`
 is the concatenated data.
 
-| Argument    | Type                                   | Description                               |
-| ----------- | -------------------------------------- | ---------------------------------------- |
-| **RETURNS** | Model[List[Array2d], Array2d]          | The layer to compute the transformation. |
+| Argument    | Type                            | Description                               |
-| ----------- | ------------------------------- | ---------------------------------------- |
+| **RETURNS** | Model[List2d, Array2d]          | The layer to compute the transformation. |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/list2array.py
@@ -1174,7 +1154,7 @@
 
-- **Input:** List[Floats2d]
+- **Input:** ListXd
 - **Output:** Ragged
 
@@ -1183,9 +1163,9 @@ Transform sequences to ragged arrays if necessary and return the ragged array.
 If sequences are already ragged, do nothing. A ragged array is a tuple
 `(data, lengths)`, where `data` is the concatenated data.
 
-| Argument    | Type                                  | Description                               |
-| ----------- | ------------------------------------- | ---------------------------------------- |
-| **RETURNS** | Model[List[Array2d], Ragged]          | The layer to compute the transformation. |
+| Argument    | Type                           | Description                               |
+| ----------- | ------------------------------ | ---------------------------------------- |
+| **RETURNS** | Model[ListXd, Ragged]          | The layer to compute the transformation. |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/list2ragged.py
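A minimal sketch of the list-to-ragged round trip described above (not part of the patch; shapes arbitrary):

```python
# Illustration only: list2ragged() concatenates sequences, ragged2list() reverses it.
import numpy
from thinc.api import chain, list2ragged, ragged2list

model = chain(list2ragged(), ragged2list())
seqs = [numpy.zeros((3, 4), dtype="f"), numpy.zeros((1, 4), dtype="f")]
assert [a.shape for a in model.predict(seqs)] == [(3, 4), (1, 4)]
```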
@@ -1195,7 +1175,7 @@
 
-- **Input:** List[Array2d]
+- **Input:** List2d
 - **Output:** Padded
 
@@ -1203,9 +1183,9 @@ Create a layer to convert a list of array inputs into
 [`Padded`](/docs/api-types#padded).
 
-| Argument    | Type                                  | Description                               |
-| ----------- | ------------------------------------- | ---------------------------------------- |
-| **RETURNS** | Model[List[Array2d], Padded]          | The layer to compute the transformation. |
+| Argument    | Type                           | Description                               |
+| ----------- | ------------------------------ | ---------------------------------------- |
+| **RETURNS** | Model[List2d, Padded]          | The layer to compute the transformation. |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/list2padded.py
@@ -1216,15 +1196,15 @@
 
 - **Input:** Ragged
-- **Output:** List[Floats2d]
+- **Output:** ListXd
 
 
 Transform sequences from a ragged format into lists.
 
-| Argument    | Type                                   | Description                               |
-| ----------- | -------------------------------------- | ---------------------------------------- |
-| **RETURNS** | Model[Ragged, List[Floats2d]]          | The layer to compute the transformation. |
+| Argument    | Type                           | Description                               |
+| ----------- | ------------------------------ | ---------------------------------------- |
+| **RETURNS** | Model[Ragged, ListXd]          | The layer to compute the transformation. |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/ragged2list.py
@@ -1235,16 +1215,16 @@
 
 - **Input:** Padded
-- **Output:** List[Array]
+- **Output:** List2d
 
 
 Create a layer to convert a [`Padded`](/docs/api-types#padded) input into a
 list of arrays.
 
-| Argument    | Type                                | Description                               |
-| ----------- | ----------------------------------- | ---------------------------------------- |
-| **RETURNS** | Model[Padded, List[Array]]          | The layer to compute the transformation. |
+| Argument    | Type                           | Description                               |
+| ----------- | ------------------------------ | ---------------------------------------- |
+| **RETURNS** | Model[Padded, List2d]          | The layer to compute the transformation. |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/padded2list.py
@@ -1298,7 +1278,7 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/strings2arrays.py
 ```
 
-- **Input / output:** Union[Padded, Ragged, List[Array2d], ArrayXd]
+- **Input / output:** Union[Padded, Ragged, ListXd, ArrayXd]
 
@@ -1308,7 +1288,7 @@ input is an array, it is passed through unchanged.
 
 | Argument       | Type                             | Description                   |
 | -------------- | -------------------------------- | ----------------------------- |
-| `layer`        | Model[Array2d, Array2d]          | The layer to wrap.            |
+| `layer`        | Model[ArrayXd, ArrayXd]          | The layer to wrap.            |
 | _keyword-only_ |                                  |                               |
 | `pad`          | int                              | The padding. Defaults to `0`. |
 | **RETURNS**    | Model                            | The wrapped layer.            |
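A minimal usage sketch for `with_array` wrapping an array-to-array layer so it also accepts a list of arrays (not part of the patch; sizes arbitrary):

```python
# Illustration only: the wrapped Relu sees one concatenated array internally.
import numpy
from thinc.api import with_array, Relu

model = with_array(Relu(nO=5, nI=4))
seqs = [numpy.zeros((3, 4), dtype="f"), numpy.zeros((2, 4), dtype="f")]
model.initialize(X=seqs)
assert [a.shape for a in model.predict(seqs)] == [(3, 5), (2, 5)]
```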
@@ -1321,7 +1301,7 @@
 
-- **Input / output:** Union[Padded, Ragged, List[Array2d], Array2d]
+- **Input / output:** Union[Padded, Ragged, List2d, Array2d]
 
@@ -1348,17 +1328,17 @@
 
 - **Input:** Sequence[Sequence[Any]]
-- **Output:** List[Array2d]
+- **Output:** ListXd
 
 
 Flatten nested inputs on the way into a layer and reverse the transformation
 over the outputs.
 
-| Argument    | Type           | Description        |
-| ----------- | -------------- | ------------------ |
-| `layer`     | Model          | The layer to wrap. |
-| **RETURNS** | Model          | The wrapped layer. |
+| Argument    | Type                                                              | Description        |
+| ----------- | ----------------------------------------------------------------- | ------------------ |
+| `layer`     | Model[Sequence[Sequence[Any]], Sequence[Sequence[Any]]]           | The layer to wrap. |
+| **RETURNS** | Model[ListXd, ListXd]                                             | The wrapped layer. |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/with_flatten.py
@@ -1368,7 +1348,7 @@
 
-- **Input / output:** Union[Padded, Ragged, List[Array2d], Floats3d,
+- **Input / output:** Union[Padded, Ragged, List2d, Floats3d,
   Tuple[Floats3d, Ints1d, Ints1d, Ints1d]]
 
@@ -1389,7 +1369,7 @@
 
-- **Input / output:** Union[Padded, Ragged, List[Array2d], Floats3d,
+- **Input / output:** Union[Padded, Ragged, ListXd, Floats3d,
   Tuple[Floats2d, Ints1d]]
 
@@ -1410,17 +1390,17 @@
 
-- **Input / output:** Union[Padded, Ragged, List[Array2d]]
+- **Input / output:** Union[Padded, Ragged, List2d]
 
 
 Convert sequence input into lists on the way into a layer and reverse the
 transformation on the outputs.
 
-| Argument    | Type                                         | Description        |
-| ----------- | -------------------------------------------- | ------------------ |
-| `layer`     | Model[List[Array2d], List[Array2d]]          | The layer to wrap. |
-| **RETURNS** | Model                                        | The wrapped layer. |
+| Argument    | Type                           | Description        |
+| ----------- | ------------------------------ | ------------------ |
+| `layer`     | Model[List2d, List2d]          | The layer to wrap. |
+| **RETURNS** | Model                          | The wrapped layer. |
 
 ```python
 https://github.com/explosion/thinc/blob/master/thinc/layers/with_list.py