diff --git a/.github/workflows/autoblack.yml b/.github/workflows/autoblack.yml new file mode 100644 index 000000000..4109acce7 --- /dev/null +++ b/.github/workflows/autoblack.yml @@ -0,0 +1,44 @@ +# GitHub Action that uses Black to reformat all Python code and submits a PR +# at regular intervals. Inspired by: https://github.com/cclauss/autoblack + +name: autoblack +on: + workflow_dispatch: # allow manual trigger + schedule: + - cron: '0 8 * * 5' # every Friday at 8am UTC + +jobs: + autoblack: + if: github.repository_owner == 'explosion' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + ref: ${{ github.head_ref }} + - uses: actions/setup-python@v2 + - run: pip install black + - name: Auto-format code if needed + run: black thinc + # We can't run black --check here because that returns a non-zero exit + # code and makes GitHub think the action failed + - name: Check for modified files + id: git-check + run: echo ::set-output name=modified::$(if git diff-index --quiet HEAD --; then echo "false"; else echo "true"; fi) + - name: Create Pull Request + if: steps.git-check.outputs.modified == 'true' + uses: peter-evans/create-pull-request@v3 + with: + title: Auto-format code with black + labels: meta + commit-message: Auto-format code with black + committer: GitHub + author: explosion-bot + body: _This PR is auto-generated._ + branch: autoblack + delete-branch: true + draft: false + - name: Check outputs + if: steps.git-check.outputs.modified == 'true' + run: | + echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" + echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" diff --git a/.github/workflows/explosionbot.yml b/.github/workflows/explosionbot.yml index 9c5be1366..b32b65052 100644 --- a/.github/workflows/explosionbot.yml +++ b/.github/workflows/explosionbot.yml @@ -23,5 +23,5 @@ jobs: env: INPUT_TOKEN: ${{ secrets.EXPLOSIONBOT_TOKEN }} INPUT_BK_TOKEN: ${{ secrets.BUILDKITE_SECRET }} - ENABLED_COMMANDS: "test_gpu" + ENABLED_COMMANDS: "test_gpu,test_slow,test_slow_gpu" ALLOWED_TEAMS: "spacy-maintainers" \ No newline at end of file diff --git a/README.md b/README.md index b0c7f3937..488f5c9db 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # Thinc: A refreshing functional take on deep learning, compatible with your favorite libraries -### From the makers of [spaCy](https://spacy.io), [Prodigy](https://prodi.gy) and [FastAPI](https://fastapi.tiangolo.com) +### From the makers of [spaCy](https://spacy.io) and [Prodigy](https://prodi.gy) [Thinc](https://thinc.ai) is a **lightweight deep learning library** that offers an elegant, type-checked, functional-programming API for **composing models**, with support diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 301ab7341..3789606a4 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -23,7 +23,7 @@ jobs: imageName: 'windows-2019' python.version: '3.6' Python37Mac: - imageName: 'macos-10.15' + imageName: 'macos-latest' python.version: '3.7' Python38Linux: imageName: 'ubuntu-latest' @@ -63,6 +63,7 @@ jobs: - script: | python -m mypy thinc displayName: 'Run mypy' + condition: ne(variables['python.version'], '3.6') - task: DeleteFiles@1 inputs: @@ -82,25 +83,37 @@ - script: | pip install -r requirements.txt - pip install "tensorflow~=2.5.0" - pip install "mxnet; sys_platform != 'win32'" - pip install "torch==1.9.0+cpu" -f https://download.pytorch.org/whl/torch_stable.html pip install ipykernel pydot graphviz python -m ipykernel install --name
thinc-notebook-tests --user - displayName: 'Install test dependencies' + python -m pytest --pyargs thinc --cov=thinc --cov-report=term + displayName: 'Run tests without extras' + + - script: | + pip install "protobuf~=3.20.0" "tensorflow~=2.5.0" + pip install "mxnet; sys_platform != 'win32'" + pip install torch --extra-index-url https://download.pytorch.org/whl/cpu + # torch does not have a direct numpy requirement but is compiled against + # a newer version than the oldest supported numpy for windows and + # python 3.10; this version of numpy would not work with + # tensorflow~=2.5.0 as specified above, but there is no release for + # python 3.10 anyway + pip install "numpy~=1.23.0; python_version=='3.10' and sys_platform=='win32'" + pip install -r requirements.txt + pip uninstall -y mypy + displayName: 'Install extras for testing' - script: | python -m pytest --pyargs thinc --cov=thinc --cov-report=term - displayName: 'Run tests' + displayName: 'Run tests with extras' - script: | pip uninstall -y tensorflow pip install thinc-apple-ops python -m pytest --pyargs thinc_apple_ops displayName: 'Run tests for thinc-apple-ops' - condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.9')) + condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.10')) - script: | python -m pytest --pyargs thinc displayName: 'Run tests with thinc-apple-ops' - condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.9')) + condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.10')) diff --git a/build-constraints.txt b/build-constraints.txt index cf5fe3284..956973abf 100644 --- a/build-constraints.txt +++ b/build-constraints.txt @@ -1,6 +1,8 @@ # build version constraints for use with wheelwright + multibuild -numpy==1.15.0; python_version<='3.7' -numpy==1.17.3; python_version=='3.8' +numpy==1.15.0; python_version<='3.7' and platform_machine!='aarch64' +numpy==1.19.2; python_version<='3.7' and platform_machine=='aarch64' +numpy==1.17.3; python_version=='3.8' and platform_machine!='aarch64' +numpy==1.19.2; python_version=='3.8' and platform_machine=='aarch64' numpy==1.19.3; python_version=='3.9' numpy==1.21.3; python_version=='3.10' numpy; python_version>='3.11' diff --git a/examples/transformers_tagger.py b/examples/transformers_tagger.py index 88052ba1f..058d5af24 100644 --- a/examples/transformers_tagger.py +++ b/examples/transformers_tagger.py @@ -132,7 +132,9 @@ def forward( return TokensPlus(**token_data), lambda d_tokens: [] return Model( - "tokenizer", forward, attrs={"tokenizer": AutoTokenizer.from_pretrained(name)}, + "tokenizer", + forward, + attrs={"tokenizer": AutoTokenizer.from_pretrained(name)}, ) @@ -166,11 +168,14 @@ def convert_transformer_outputs(model, inputs_outputs, is_train): def backprop(d_tokvecs: List[Floats2d]) -> ArgsKwargs: # Restore entries for bos and eos markers. 
+ shim = model.shims[0] row = model.ops.alloc2f(1, d_tokvecs[0].shape[1]) d_tokvecs = [model.ops.xp.vstack((row, arr, row)) for arr in d_tokvecs] return ArgsKwargs( args=(torch_tokvecs,), - kwargs={"grad_tensors": xp2torch(model.ops.pad(d_tokvecs))}, + kwargs={ + "grad_tensors": xp2torch(model.ops.pad(d_tokvecs, device=shim.device)) + }, ) return tokvecs, backprop diff --git a/pyproject.toml b/pyproject.toml index ff13274ee..d3fb69b76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ "murmurhash>=1.0.2,<1.1.0", "cymem>=2.0.2,<2.1.0", "preshed>=3.0.2,<3.1.0", - "blis>=0.4.0,<0.8.0", + "blis>=0.7.8,<0.8.0", "numpy>=1.15.0", ] build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index 4a31fc017..3f9ab983b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,17 +2,18 @@ murmurhash>=1.0.2,<1.1.0 cymem>=2.0.2,<2.1.0 preshed>=3.0.2,<3.1.0 -blis>=0.4.0,<0.8.0 +blis>=0.7.8,<0.8.0 srsly>=2.4.0,<3.0.0 wasabi>=0.8.1,<1.1.0 catalogue>=2.0.4,<2.1.0 +confection>=0.0.1,<1.0.0 ml_datasets>=0.2.0,<0.3.0 # Third-party dependencies -pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0 +pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0 numpy>=1.15.0 # Backports of modern Python features dataclasses>=0.6,<1.0; python_version < "3.7" -typing_extensions>=3.7.4.1,<4.0.0.0; python_version < "3.8" +typing_extensions>=3.7.4.1,<4.2.0; python_version < "3.8" contextvars>=2.4,<3; python_version < "3.7" # Development dependencies cython>=0.25.0,<3.0 @@ -22,7 +23,7 @@ pytest-cov>=2.7.0,<2.8.0 coverage>=5.0.0,<6.0.0 mock>=2.0.0,<3.0.0 flake8>=3.5.0,<3.6.0 -mypy>=0.901,<0.960 +mypy>=0.980,<0.990; platform_machine != "aarch64" and python_version >= "3.7" types-mock>=0.1.1 types-contextvars>=0.1.2; python_version < "3.7" types-dataclasses>=0.1.3; python_version < "3.7" @@ -33,3 +34,4 @@ nbconvert>=5.6.1,<6.2.0 nbformat>=5.0.4,<5.2.0 # Test to_disk/from_disk against pathlib.Path subclasses pathy>=0.3.5 +black>=22.0,<23.0 diff --git a/setup.cfg b/setup.cfg index 4e535d899..a0347c55e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,24 +35,29 @@ setup_requires = cymem>=2.0.2,<2.1.0 preshed>=3.0.2,<3.1.0 murmurhash>=1.0.2,<1.1.0 - blis>=0.4.0,<0.8.0 + blis>=0.7.8,<0.8.0 install_requires = # Explosion-provided dependencies - blis>=0.4.0,<0.8.0 + blis>=0.7.8,<0.8.0 murmurhash>=1.0.2,<1.1.0 cymem>=2.0.2,<2.1.0 preshed>=3.0.2,<3.1.0 wasabi>=0.8.1,<1.1.0 srsly>=2.4.0,<3.0.0 catalogue>=2.0.4,<2.1.0 + confection>=0.0.1,<1.0.0 # Third-party dependencies setuptools numpy>=1.15.0 - pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0 + pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0 # Backports of modern Python features dataclasses>=0.6,<1.0; python_version < "3.7" - typing_extensions>=3.7.4.1,<4.0.0.0; python_version < "3.8" + typing_extensions>=3.7.4.1,<4.2.0; python_version < "3.8" contextvars>=2.4,<3; python_version < "3.7" + +[options.entry_points] +pytest_randomly.random_seeder = + thinc = thinc.api:fix_random_seed [options.extras_require] cuda = @@ -83,6 +88,14 @@ cuda114 = cupy-cuda114>=5.0.0b4 cuda115 = cupy-cuda115>=5.0.0b4 +cuda116 = + cupy-cuda116>=5.0.0b4 +cuda117 = + cupy-cuda117>=5.0.0b4 +cuda11x = + cupy-cuda11x>=11.0.0 +cuda-autodetect = + cupy-wheel>=11.0.0 datasets = ml_datasets>=0.2.0,<0.3.0 torch = diff --git a/setup.py b/setup.py index 3ce787cdc..27873beeb 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ PACKAGES = find_packages() MOD_NAMES = [ + "thinc.backends.cblas", "thinc.backends.linalg", "thinc.backends.numpy_ops", "thinc.extra.search", @@ -24,7 +25,7 @@ ] COMPILE_OPTIONS = { "msvc": ["/Ox", 
"/EHsc"], - "other": ["-O3", "-Wno-strict-prototypes", "-Wno-unused-function"], + "other": ["-O3", "-Wno-strict-prototypes", "-Wno-unused-function", "-std=c++11"], } COMPILER_DIRECTIVES = { "language_level": -3, diff --git a/thinc/about.py b/thinc/about.py index 0c63069de..35cdb6adb 100644 --- a/thinc/about.py +++ b/thinc/about.py @@ -1,2 +1,2 @@ -__version__ = "8.0.15" +__version__ = "8.1.3" __release__ = True diff --git a/thinc/api.py b/thinc/api.py index 8e1ea4bf9..8c5807347 100644 --- a/thinc/api.py +++ b/thinc/api.py @@ -16,8 +16,10 @@ from .util import DataValidationError, data_validation from .util import to_categorical, get_width, get_array_module, to_numpy from .util import torch2xp, xp2torch, tensorflow2xp, xp2tensorflow, mxnet2xp, xp2mxnet +from .util import get_torch_default_device +from .compat import has_cupy from .backends import get_ops, set_current_ops, get_current_ops, use_ops -from .backends import Ops, CupyOps, NumpyOps, has_cupy, set_gpu_allocator +from .backends import Ops, CupyOps, MPSOps, NumpyOps, set_gpu_allocator from .backends import use_pytorch_for_gpu_memory, use_tensorflow_for_gpu_memory from .layers import Dropout, Embed, expand_window, HashEmbed, LayerNorm, Linear @@ -25,7 +27,7 @@ from .layers import CauchySimilarity, ParametricAttention, Logistic from .layers import resizable, sigmoid_activation, Sigmoid, SparseLinear from .layers import ClippedLinear, ReluK, HardTanh, HardSigmoid -from .layers import HardSwish, HardSwishMobilenet, Swish, Gelu +from .layers import Dish, HardSwish, HardSwishMobilenet, Swish, Gelu from .layers import PyTorchWrapper, PyTorchRNNWrapper, PyTorchLSTM from .layers import TensorFlowWrapper, keras_subclass, MXNetWrapper from .layers import PyTorchWrapper_v2, Softmax_v2 @@ -38,6 +40,7 @@ from .layers import with_reshape, with_getitem, strings2arrays, list2array from .layers import list2ragged, ragged2list, list2padded, padded2list, remap_ids from .layers import array_getitem, with_cpu, with_debug, with_nvtx_range +from .layers import with_signpost_interval from .layers import tuplify from .layers import reduce_first, reduce_last, reduce_max, reduce_mean, reduce_sum diff --git a/thinc/backends/__init__.py b/thinc/backends/__init__.py index ba4daaa8a..c21620126 100644 --- a/thinc/backends/__init__.py +++ b/thinc/backends/__init__.py @@ -5,13 +5,15 @@ import threading from .ops import Ops -from .cupy_ops import CupyOps, has_cupy +from .cupy_ops import CupyOps from .numpy_ops import NumpyOps +from .mps_ops import MPSOps from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator from ._param_server import ParamServer from ..util import assert_tensorflow_installed, assert_pytorch_installed -from ..util import is_cupy_array, set_torch_tensor_type_for_ops, require_cpu +from ..util import get_torch_default_device, is_cupy_array, require_cpu from .. import registry +from ..compat import cupy, has_cupy context_ops: ContextVar[Optional[Ops]] = ContextVar("context_ops", default=None) @@ -46,9 +48,11 @@ def use_pytorch_for_gpu_memory() -> None: # pragma: no cover We'd like to support routing Tensorflow memory allocation via PyTorch as well (or vice versa), but do not currently have an implementation for it. 
""" - import cupy.cuda - assert_pytorch_installed() + + if get_torch_default_device().type != "cuda": + return + pools = context_pools.get() if "pytorch" not in pools: pools["pytorch"] = cupy.cuda.MemoryPool(allocator=cupy_pytorch_allocator) @@ -65,8 +69,6 @@ def use_tensorflow_for_gpu_memory() -> None: # pragma: no cover We'd like to support routing PyTorch memory allocation via Tensorflow as well (or vice versa), but do not currently have an implementation for it. """ - import cupy.cuda - assert_tensorflow_installed() pools = context_pools.get() if "tensorflow" not in pools: @@ -94,7 +96,7 @@ def get_ops(name: str, **kwargs) -> Ops: cls: Optional[Callable[..., Ops]] = None if name == "cpu": - _import_extra_cpu_backends() + _import_extra_cpu_backends() cls = ops_by_name.get("numpy") cls = ops_by_name.get("apple", cls) cls = ops_by_name.get("bigendian", cls) @@ -137,7 +139,6 @@ def set_current_ops(ops: Ops) -> None: """Change the current backend object.""" context_ops.set(ops) _get_thread_state().ops = ops - set_torch_tensor_type_for_ops(ops) def contextvars_eq_thread_ops() -> bool: @@ -173,6 +174,7 @@ def _create_thread_local( "ParamServer", "Ops", "CupyOps", + "MPSOps", "NumpyOps", "has_cupy", ] diff --git a/thinc/backends/_cupy_allocators.py b/thinc/backends/_cupy_allocators.py index 7bcb8bd5c..f2b6faee9 100644 --- a/thinc/backends/_cupy_allocators.py +++ b/thinc/backends/_cupy_allocators.py @@ -1,23 +1,8 @@ from typing import cast from ..types import ArrayXd -from ..util import tensorflow2xp - -try: - import tensorflow -except ImportError: - pass - -try: - import torch -except ImportError: - pass - -try: - from cupy.cuda.memory import MemoryPointer - from cupy.cuda.memory import UnownedMemory -except ImportError: - pass +from ..util import get_torch_default_device, tensorflow2xp +from ..compat import torch, cupy, tensorflow def cupy_tensorflow_allocator(size_in_bytes: int): @@ -32,12 +17,13 @@ def cupy_tensorflow_allocator(size_in_bytes: int): cupy_array = cast(ArrayXd, tensorflow2xp(tensor)) address = int(cupy_array.data) # cupy has a neat class to help us here. Otherwise it will try to free. - memory = UnownedMemory(address, size_in_bytes, cupy_array) + memory = cupy.cuda.memory.UnownedMemory(address, size_in_bytes, cupy_array) # Now return a new memory pointer. - return MemoryPointer(memory, 0) + return cupy.cuda.memory.MemoryPointer(memory, 0) def cupy_pytorch_allocator(size_in_bytes: int): + device = get_torch_default_device() """Function that can be passed into cupy.cuda.set_allocator, to have cupy allocate memory via PyTorch. This is important when using the two libraries together, as otherwise OOM errors can occur when there's available memory @@ -49,10 +35,12 @@ def cupy_pytorch_allocator(size_in_bytes: int): # creating a whole Tensor. # This turns out to be way faster than making FloatStorage? Maybe # a Python vs C++ thing I guess? - torch_tensor = torch.zeros((size_in_bytes // 4,), requires_grad=False) + torch_tensor = torch.zeros( + (size_in_bytes // 4,), requires_grad=False, device=device + ) # cupy has a neat class to help us here. Otherwise it will try to free. # I think this is a private API? It's not in the types. address = torch_tensor.data_ptr() # type: ignore - memory = UnownedMemory(address, size_in_bytes, torch_tensor) + memory = cupy.cuda.memory.UnownedMemory(address, size_in_bytes, torch_tensor) # Now return a new memory pointer. 
- return MemoryPointer(memory, 0) + return cupy.cuda.memory.MemoryPointer(memory, 0) diff --git a/thinc/backends/_custom_kernels.cu b/thinc/backends/_custom_kernels.cu index f0cebd77d..9c9fece1e 100644 --- a/thinc/backends/_custom_kernels.cu +++ b/thinc/backends/_custom_kernels.cu @@ -22,6 +22,26 @@ struct Constants { }; +template +__global__ void gather_add(U* out_bo, const U* table_to, const int* indices_bk, + int T, int O, int B, int K) +{ + int _loop_start = blockIdx.x * blockDim.x + threadIdx.x; + int _loop_stride = blockDim.x * gridDim.x; + + for (int b = _loop_start; b < B; b += _loop_stride) { + for (int k = 0; k < K; ++k) { + int idx = indices_bk[b * K + k]; + const U* table = table_to + idx * O; + U* out = out_bo + b * O; + for (int o = 0; o < O; ++o) { + out[o] += table[o]; + } + } + } +} + + template __global__ void seq2col(T* output, const T* X, const int* lengths, int nW, int B, int I, int nL) @@ -141,6 +161,20 @@ __global__ void clipped_linear(T* Y, const T* X, double slope, double offset, do } +template +__global__ void dish(T* Y, const T* X, int N) +{ + int _loop_start = blockIdx.x * blockDim.x + threadIdx.x; + int _loop_stride = blockDim.x * gridDim.x; + + for (int i = _loop_start; i < N; i += _loop_stride) + { + T x = X[i]; + Y[i] = 0.5 * x * (x / sqrt(1 + x * x) + 1); + } +} + + template __global__ void gelu(T* Y, const T* X, double threshold, int N) { @@ -394,6 +428,23 @@ __global__ void backprop_hard_swish_mobilenet(T* dX, const T* dY, const T* X, in } +template +__global__ void backprop_dish(T* dX, const T* dY, const T* X, int N) +{ + + int _loop_start = blockIdx.x * blockDim.x + threadIdx.x; + int _loop_stride = blockDim.x * gridDim.x; + + for (int i = _loop_start; i < N; i += _loop_stride) + { + T x = X[i]; + T x_sq = x * x; + T x_sq_plus_one = x_sq + 1.0; + dX[i] = dY[i] * (x/sqrt(x_sq_plus_one) - (0.5 * x * x_sq) + / pow(x_sq_plus_one, static_cast(1.5)) + 0.5); + } +} + template __global__ void backprop_gelu(T* dX, const T* dY, const T* X, diff --git a/thinc/backends/_custom_kernels.py b/thinc/backends/_custom_kernels.py index d5456114a..859405495 100644 --- a/thinc/backends/_custom_kernels.py +++ b/thinc/backends/_custom_kernels.py @@ -2,11 +2,7 @@ import re from pathlib import Path from collections import defaultdict - -try: - import cupy -except ImportError: - cupy = None +from ..compat import cupy, has_cupy_gpu PWD = Path(__file__).parent @@ -14,6 +10,8 @@ KERNELS_LIST = [ "backprop_clipped_linear", "backprop_clipped_linear", + "backprop_dish", + "backprop_dish", "backprop_gelu", "backprop_gelu", "backprop_hard_swish", @@ -36,6 +34,10 @@ "backprop_swish", "clipped_linear", "clipped_linear", + "dish", + "dish", + "gather_add", + "gather_add", "gelu", "gelu", "maxout", @@ -55,7 +57,7 @@ cupy.RawModule( code=KERNELS_SRC, options=("--std=c++11",), name_expressions=KERNELS_LIST ) - if cupy is not None + if has_cupy_gpu else None ) @@ -70,7 +72,7 @@ def _get_kernel(name): def compile_mmh(src): - if cupy is None: + if not has_cupy_gpu: return None return cupy.RawKernel(src, "hash_data") @@ -80,6 +82,10 @@ def compile_mmh(src): clipped_linear_kernel_float = _get_kernel("clipped_linear") clipped_linear_kernel_double = _get_kernel("clipped_linear") +dish_kernel_float = _get_kernel("dish") +dish_kernel_double = _get_kernel("dish") +gather_add_kernel_float = _get_kernel("gather_add") +gather_add_kernel_double = _get_kernel("gather_add") gelu_kernel_float = _get_kernel("gelu") gelu_kernel_double = _get_kernel("gelu") hash_data_kernel = compile_mmh(MMH_SRC) @@ -98,6 +104,8 
@@ def compile_mmh(src): backprop_clipped_linear_kernel_double = _get_kernel("backprop_clipped_linear") backprop_clipped_linear_kernel_float = _get_kernel("backprop_clipped_linear") +backprop_dish_kernel_double = _get_kernel("backprop_dish") +backprop_dish_kernel_float = _get_kernel("backprop_dish") backprop_gelu_kernel_double = _get_kernel("backprop_gelu") backprop_gelu_kernel_float = _get_kernel("backprop_gelu") backprop_hard_swish_kernel_double = _get_kernel("backprop_hard_swish") @@ -169,6 +177,49 @@ def clipped_linear( return out +def gather_add(table, indices, *, threads_per_block=128, num_blocks=128): + if table.ndim != 2: + raise ValueError( + f"gather_add expects table with dimensionality 2, was: {table.ndim}" + ) + if indices.ndim != 2: + raise ValueError( + f"gather_add expects indices with dimensionality 2, was: {indices.ndim}" + ) + _is_float_array(table) + indices = indices.astype("int32") + _check_indices(indices, table.shape[0]) + + B = indices.shape[0] + K = indices.shape[1] + T = table.shape[0] + O = table.shape[1] + + out = _alloc((B, O), dtype=table.dtype, zeros=True) + if table.dtype == "float32": + gather_add_kernel_float( + (num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K) + ) + else: + gather_add_kernel_double( + (num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K) + ) + return out + + +def dish(X, *, inplace=False, threads_per_block=128, num_blocks=128): + _is_float_array(X) + + out = X + if not inplace: + out = _alloc_like(X, zeros=False) + if X.dtype == "float32": + dish_kernel_float((num_blocks,), (threads_per_block,), (out, X, X.size)) + else: + dish_kernel_double((num_blocks,), (threads_per_block,), (out, X, X.size)) + return out + + def gelu(X, *, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=128): _is_float_array(X) @@ -453,6 +504,33 @@ def backprop_hard_swish_mobilenet( return out +def backprop_dish( + dY, + X, + *, + inplace: bool = False, + threads_per_block=128, + num_blocks=128, +): + _is_float_array(dY) + _is_float_array(X, shape=dY.shape) + + out = dY + if not inplace: + out = _alloc_like(dY, zeros=False) + + if dY.dtype == "float32": + backprop_dish_kernel_float( + (num_blocks,), (threads_per_block,), (out, dY, X, out.size) + ) + else: + backprop_dish_kernel_double( + (num_blocks,), (threads_per_block,), (out, dY, X, out.size) + ) + + return out + + def backprop_gelu( dY, X, @@ -651,6 +729,13 @@ def _check_lengths(lengths, n_elems: int, *, min_length=0): raise IndexError("lengths must sum up to the batch size") +def _check_indices(indices, n: int): + assert indices.dtype == "int32", "indices should be encoded as 32-bit integers" + + if not _values_within_range(indices, 0, n): + raise IndexError(f"index out of bounds, must be >= 0 && < {n}") + + def _check_which_maxout(which, B: int, I: int, P: int): shape = (B, I) msg = "maximum index (which) should be encoded as 32-bit integers" @@ -672,7 +757,7 @@ def _check_which_maxout(which, B: int, I: int, P: int): "true", "within_range", ) - if cupy is not None + if has_cupy_gpu else None ) diff --git a/thinc/backends/cblas.pxd b/thinc/backends/cblas.pxd new file mode 100644 index 000000000..15837e5e7 --- /dev/null +++ b/thinc/backends/cblas.pxd @@ -0,0 +1,38 @@ +from libcpp.memory cimport shared_ptr + + +ctypedef void (*sgemm_ptr)(bint transA, bint transB, int M, int N, int K, + float alpha, const float* A, int lda, const float *B, + int ldb, float beta, float* C, int ldc) nogil + + +ctypedef void (*saxpy_ptr)(int N, float alpha, const float* X, int 
incX, + float *Y, int incY) nogil + + +ctypedef void (*daxpy_ptr)(int N, double alpha, const double* X, int incX, + double *Y, int incY) nogil + + +# Forward-declaration of the BlasFuncs struct. This struct must be opaque, so +# that consumers of the CBlas class cannot become dependent on its size or +# ordering. +cdef struct BlasFuncs + + +cdef class CBlas: + cdef shared_ptr[BlasFuncs] ptr + + +# Note: the following functions are intentionally standalone. If we make them +# methods of CBlas, Cython will generate and use a vtable. This makes it +# impossible to add new BLAS functions later without breaking the ABI. +# +# See https://github.com/explosion/thinc/pull/700 for more information. + +cdef daxpy_ptr daxpy(CBlas cblas) nogil +cdef saxpy_ptr saxpy(CBlas cblas) nogil +cdef sgemm_ptr sgemm(CBlas cblas) nogil +cdef void set_daxpy(CBlas cblas, daxpy_ptr daxpy) nogil +cdef void set_saxpy(CBlas cblas, saxpy_ptr saxpy) nogil +cdef void set_sgemm(CBlas cblas, sgemm_ptr sgemm) nogil diff --git a/thinc/backends/cblas.pyx b/thinc/backends/cblas.pyx new file mode 100644 index 000000000..9eb4514d8 --- /dev/null +++ b/thinc/backends/cblas.pyx @@ -0,0 +1,40 @@ +cimport blis.cy +from cython.operator cimport dereference as deref +from libcpp.memory cimport make_shared + + +cdef struct BlasFuncs: + daxpy_ptr daxpy + saxpy_ptr saxpy + sgemm_ptr sgemm + + +cdef class CBlas: + __slots__ = [] + + def __init__(self): + """Construct a CBlas instance set to use BLIS implementations of the + supported BLAS functions.""" + cdef BlasFuncs funcs + funcs.daxpy = blis.cy.daxpy + funcs.saxpy = blis.cy.saxpy + funcs.sgemm = blis.cy.sgemm + self.ptr = make_shared[BlasFuncs](funcs) + +cdef daxpy_ptr daxpy(CBlas cblas) nogil: + return deref(cblas.ptr).daxpy + +cdef saxpy_ptr saxpy(CBlas cblas) nogil: + return deref(cblas.ptr).saxpy + +cdef sgemm_ptr sgemm(CBlas cblas) nogil: + return deref(cblas.ptr).sgemm + +cdef void set_daxpy(CBlas cblas, daxpy_ptr daxpy) nogil: + deref(cblas.ptr).daxpy = daxpy + +cdef void set_saxpy(CBlas cblas, saxpy_ptr saxpy) nogil: + deref(cblas.ptr).saxpy = saxpy + +cdef void set_sgemm(CBlas cblas, sgemm_ptr sgemm) nogil: + deref(cblas.ptr).sgemm = sgemm diff --git a/thinc/backends/cpu_kernels.hh b/thinc/backends/cpu_kernels.hh index 69e1ab334..95808eee1 100644 --- a/thinc/backends/cpu_kernels.hh +++ b/thinc/backends/cpu_kernels.hh @@ -8,26 +8,77 @@ #include #include +// Ideally we'd use an alias declaration for a generic definition of +// *axpy. But Cython doesn't support alias declarations yet: +// +// https://github.com/cython/cython/issues/3272 +// +// template +// using axpy = void (*)(int N, T alpha, const T* X, int incX, +// T *Y, int incY); +// +// So, instead we'll do this the pre-C++11 way: + +template +struct axpy { + typedef void (*ptr)(int N, T alpha, const T* X, int incX, T *Y, int incY); +}; + // All elementwise functions, such as most activations, work in-place. 
-template -L argmax(A* arr, L len) + +template +struct argmax_result { + T max; + L max_idx; +}; + +template +argmax_result argmax(T const *arr, L len) { - static_assert(std::is_floating_point::value, + static_assert(std::is_floating_point::value, "Array should be floating point"); static_assert(std::is_integral::value, "Array length should be integral"); - L max = 0; + argmax_result r { arr[0], 0 }; + for (L i = 1; i < len; ++i) { - if (arr[i] > arr[max]) { - max = i; + if (arr[i] > r.max) { + r.max = arr[i]; + r.max_idx = i; } } - return max; + return r; +} + +// The next two templates define argmax for a fixed number of elements. + +template +argmax_result argmax(T a) { + static_assert(std::is_floating_point::value, "Argument should be floating point"); + argmax_result acc { a, 0 }; + return acc; +} + +template +argmax_result argmax(T a, Args... args) { + static_assert(std::is_floating_point::value, "Arguments should be floating point"); + + auto acc = argmax(args...); + + if (acc.max > a) { + acc.max_idx += 1; + } else { + acc.max_idx = 0; + acc.max = a; + } + + return acc; } + template void vec_add(A* X, const A* Y, A scale, L N) { @@ -46,12 +97,31 @@ void cpu_maxout(A* best__bo, L* which__bo, const A* cands__bop, L B, L O, L P) "Array should be floating point"); static_assert(std::is_integral::value, "Array length should be integral"); - for (int i = 0; i < B * O; ++i) { - which__bo[i] = argmax(cands__bop + i * P, P); - best__bo[i] = cands__bop[i * P + which__bo[i]]; + // For small inputs, we use an unrolled argmax. + if (P == 2) { + for (int i = 0; i < B * O; ++i) { + A const *input = cands__bop + i * P; + auto r = argmax(input[0], input[1]); + which__bo[i] = r.max_idx; + best__bo[i] = r.max; + } + } else if (P == 3) { + for (int i = 0; i < B * O; ++i) { + A const *input = cands__bop + i * P; + auto r = argmax(input[0], input[1], input[2]); + which__bo[i] = r.max_idx; + best__bo[i] = r.max; + } + } else { + for (int i = 0; i < B * O; ++i) { + auto r = argmax(cands__bop + i * P, P); + which__bo[i] = r.max_idx; + best__bo[i] = r.max; + } } } + template void cpu_backprop_maxout(A* dX__bop, const A* dX__bo, const L* which__bo, L B, L O, L P) @@ -395,4 +465,18 @@ void backprop_seq2col(A* d_seqs, const A* d_cols, const L* lengths, L B, L I, L } } +template +void cpu_gather_add(typename axpy::ptr axpy, F* out_bo, const F* table_to, const I* indices_bk, L T, L O, L B, L K) { + for (L b = 0; b < B; ++b) { + for (L k = 0; k < K; ++k) { + I idx = indices_bk[b * K + k]; + if (idx > T) { + throw std::out_of_range("Embedding index out-of-bounds"); + } + axpy(O, 1.0, table_to + idx * O, 1, out_bo + b * O, 1); + } + } +} + + #endif // CPU_KERNELS_HH diff --git a/thinc/backends/cupy_ops.py b/thinc/backends/cupy_ops.py index 709775ddd..6d263c155 100644 --- a/thinc/backends/cupy_ops.py +++ b/thinc/backends/cupy_ops.py @@ -1,19 +1,4 @@ import numpy - -try: - import cupy - import cupyx - import cupy.cuda - from cupy.cuda.compiler import compile_with_cache # noqa: F401 - - has_cupy = True - - # We no longer have to set up the memory pool, fortunately. -except ImportError: - cupy = None - cupyx = None - has_cupy = False - from .. 
import registry from .ops import Ops from .numpy_ops import NumpyOps @@ -21,7 +6,8 @@ from ..types import DeviceTypes from ..util import torch2xp, tensorflow2xp, mxnet2xp from ..util import is_cupy_array -from ..util import is_torch_gpu_array, is_tensorflow_gpu_array, is_mxnet_gpu_array +from ..util import is_torch_cuda_array, is_tensorflow_gpu_array, is_mxnet_gpu_array +from ..compat import cupy, cupyx @registry.ops("CupyOps") @@ -44,6 +30,24 @@ def to_numpy(self, data, *, byte_order=None): data = numpy.asarray(data, dtype=dtype) return data + def gather_add(self, table, indices): + if table.dtype in ("float32", "float64"): + return _custom_kernels.gather_add(table, indices) + else: + return super().gather_add(table, indices) + + def dish(self, X, inplace=False): + if X.dtype in ("float32", "float64"): + return _custom_kernels.dish(X, inplace=inplace) + else: + return super().dish(X, inplace=inplace) + + def backprop_dish(self, dY, X, inplace=False): + if X.dtype == dY.dtype and X.dtype in ("float32", "float64"): + return _custom_kernels.backprop_dish(dY, X, inplace=inplace) + else: + return super().backprop_dish(dY, X, inplace=inplace) + def gelu(self, X, inplace=False): if X.dtype in ("float32", "float64"): return _custom_kernels.gelu(X, inplace=inplace, threshold=6.0) @@ -73,29 +77,20 @@ def gemm(self, x, y, out=None, trans1=False, trans2=False): return out def asarray(self, data, dtype=None): - # This is sort of frustrating, but we can't easily otherwise pass - # forward "unset". - dtype = {"dtype": dtype} if dtype is not None else {} - # We'll try to perform a zero-copy conversion if possible. - array = None - cast_array = False if is_cupy_array(data): - array = self.xp.asarray(data, **dtype) - elif is_torch_gpu_array(data): + array = data + elif is_torch_cuda_array(data): array = torch2xp(data) - cast_array = True elif is_tensorflow_gpu_array(data): array = tensorflow2xp(data) - cast_array = True elif is_mxnet_gpu_array(data): array = mxnet2xp(data) - cast_array = True else: - array = self.xp.array(data, **dtype) + array = self.xp.array(data) - if cast_array and dtype != {}: - array = array.astype(dtype["dtype"]) + if dtype is not None: + array = array.astype(dtype=dtype, copy=False) return array @@ -304,6 +299,10 @@ def scatter_add(self, table, indices, values): def adam( self, weights, gradient, mom1, mom2, beta1, beta2, eps, learn_rate, mod_rate=1.0 ): + _check_compatible_shape(weights, gradient) + _check_compatible_shape(weights, mom1) + _check_compatible_shape(weights, mom2) + adam_kernel( gradient, learn_rate, 1 - beta1, 1 - beta2, eps, weights, mom1, mom2 ) @@ -326,3 +325,9 @@ def position_encode(self, N, D, period=10000, out=None): ) else: adam_kernel = None + + +def _check_compatible_shape(u, v): + if u.shape != v.shape: + msg = f"arrays have incompatible shapes: {u.shape} and {v.shape}" + raise ValueError(msg) diff --git a/thinc/backends/mps_ops.py b/thinc/backends/mps_ops.py new file mode 100644 index 000000000..8ebbd4e4b --- /dev/null +++ b/thinc/backends/mps_ops.py @@ -0,0 +1,26 @@ +from typing import TYPE_CHECKING +import numpy + +from .. import registry +from . import NumpyOps, Ops + +if TYPE_CHECKING: + # Type checking does not work with dynamic base classes, since MyPy cannot + # determine against which base class to check. So, always derive from Ops + # during type checking. 
+ _Ops = Ops +else: + try: + from thinc_apple_ops import AppleOps + + _Ops = AppleOps + except ImportError: + _Ops = NumpyOps + + +@registry.ops("MPSOps") +class MPSOps(_Ops): + """Ops class for Metal Performance shaders.""" + + name = "mps" + xp = numpy diff --git a/thinc/backends/numpy_ops.pxd b/thinc/backends/numpy_ops.pxd index 28bab3a31..6cf01fe76 100644 --- a/thinc/backends/numpy_ops.pxd +++ b/thinc/backends/numpy_ops.pxd @@ -1,7 +1,15 @@ +from .cblas cimport saxpy_ptr + ctypedef double[:, ::1] double2d_t ctypedef double[:, :, ::1] double3d_t ctypedef float[:, ::1] float2d_t ctypedef float[:, :, ::1] float3d_t +ctypedef int[:, ::1] int2d_t +ctypedef unsigned int[:, ::1] uint2d_t + +cdef fused ints2d_ft: + int2d_t + uint2d_t cdef fused reals2d_ft: float2d_t @@ -13,6 +21,9 @@ cdef fused reals3d_ft: cdef extern from "cpu_kernels.hh": + cdef cppclass axpy[T]: + ctypedef void (*ptr)(int N, T alpha, const T* X, int incX, T *Y, int incY); + void cpu_maxout[A, L](A* best__bo, L* which__bo, const A* cands_bop, L B, L O, L P) void cpu_backprop_maxout[A, L](A* dX__bop, const A* dX__bo, const L* which__bo, @@ -35,3 +46,5 @@ cdef extern from "cpu_kernels.hh": void cpu_relu[A, L](A* X, L N) void backprop_seq2col[A, L](A* d_seqs, const A* d_cols, const L* lengths, L B, L I, L nW, L nL) void seq2col[A, L](A* output, const A* X, const L* lengths, L nW, L B, L I, L nL) + void cpu_gather_add[F, I, L](axpy[F].ptr axpy, F* out_bo, const F* table_to, const I* indices_bk, + L T, L O, L B, L K) except + diff --git a/thinc/backends/numpy_ops.pyx b/thinc/backends/numpy_ops.pyx index 9fc6f19e2..c980e6c5d 100644 --- a/thinc/backends/numpy_ops.pyx +++ b/thinc/backends/numpy_ops.pyx @@ -20,6 +20,7 @@ cimport blis.cy from .. import registry from ..util import copy_array, get_array_module from ..types import DeviceTypes, DTypes, Shape, ArrayXd +from .cblas cimport CBlas, daxpy, saxpy from .linalg cimport VecVec, Vec from .ops import Ops @@ -62,19 +63,20 @@ class NumpyOps(Ops): def asarray(self, data, dtype=None): if isinstance(data, self.xp.ndarray): - if dtype is not None: - return self.xp.asarray(data, dtype=dtype) - else: - return self.xp.asarray(data) + array = data elif hasattr(data, 'numpy'): # Handles PyTorch Tensor - return data.numpy() + array = data.numpy() elif hasattr(data, "get"): - return data.get() - elif dtype is not None: - return self.xp.array(data, dtype=dtype) + array = data.get() else: - return self.xp.array(data) + array = self.xp.array(data) + + if dtype is not None: + array = array.astype(dtype=dtype, copy=False) + + return array + def alloc(self, shape: Shape, *, dtype: Optional[DTypes] = "float32", zeros: bool = True) -> ArrayXd: if zeros: @@ -82,6 +84,9 @@ class NumpyOps(Ops): else: return self.xp.empty(shape, dtype=dtype) + def cblas(self) -> CBlas: + return CBlas() + def gemm(self, np.ndarray x, np.ndarray y, *, np.ndarray out=None, trans1=False, trans2=False): if x.ndim != 2: raise ValueError(f"Provided 'x' array should be 2-dimensional, but found {x.ndim} dimension(s).") @@ -432,6 +437,23 @@ class NumpyOps(Ops): return dX + def gather_add(self, reals2d_ft table, ints2d_ft indices): + cdef CBlas cblas = self.cblas() + rows = indices.shape[0] + dims = table.shape[1] + + cdef np.ndarray output + if reals2d_ft is float2d_t: + output = self.xp.zeros((rows, dims), dtype="float32") + cpu_gather_add(saxpy(cblas), output.data, &table[0, 0], &indices[0, 0], + table.shape[0], dims, rows, indices.shape[1]) + else: + output = self.xp.zeros((rows, dims), dtype="float64") + 
cpu_gather_add(daxpy(cblas), output.data, &table[0, 0], &indices[0, 0], + table.shape[0], dims, rows, indices.shape[1]) + + return output + def scatter_add(self, np.ndarray table, np.ndarray indices, np.ndarray values): if table.dtype == 'float32' \ and indices.dtype == 'int32' \ @@ -452,9 +474,14 @@ class NumpyOps(Ops): @cython.boundscheck(False) @cython.wraparound(False) - def adam(self, np.ndarray weights, np.ndarray gradient, np.ndarray mom1, - np.ndarray mom2, const float beta1, const float beta2, float eps, + def adam(self, np.ndarray[np.float32_t] weights, np.ndarray[np.float32_t] gradient, + np.ndarray[np.float32_t] mom1, np.ndarray[np.float32_t] mom2, + const float beta1, const float beta2, float eps, float learn_rate, float mod_rate=1.): + _check_compatible_shape(weights, gradient) + _check_compatible_shape(weights, mom1) + _check_compatible_shape(weights, mom2) + _adam_momentum(gradient.data, mom1.data, mom2.data, weights.shape[0], beta1, beta2, eps, learn_rate) VecVec.add_i(weights.data, diff --git a/thinc/backends/ops.py b/thinc/backends/ops.py index 315b0b0bf..eb9945e27 100644 --- a/thinc/backends/ops.py +++ b/thinc/backends/ops.py @@ -1,6 +1,6 @@ import math -from typing import Optional, List, Tuple, Sequence, Union, cast, TypeVar +from typing import Optional, List, Tuple, Sequence, Type, Union, cast, TypeVar from typing import Iterator, overload import numpy import itertools @@ -9,13 +9,14 @@ from ..types import Floats1d, Floats2d, Floats3d, Floats4d from ..types import Array1d, Array2d, Array3d, Array4d, ListXd from ..types import FloatsXd, Ints1d, Ints2d, Ints3d, Ints4d, IntsXd, _Floats +from ..types import FloatsXdT from ..types import DeviceTypes, Generator, Padded, Batchable, SizedGenerator from ..util import get_array_module, is_xp_array, to_numpy +from .cblas import CBlas ArrayT = TypeVar("ArrayT", bound=ArrayXd) FloatsT = TypeVar("FloatsT", bound=_Floats) -FloatsType = TypeVar("FloatsType", bound=FloatsXd) SQRT2PI = math.sqrt(2.0 / math.pi) INV_SQRT2 = 1.0 / math.sqrt(2.0) INV_SQRT_2PI = 1.0 / math.sqrt(2.0 * math.pi) @@ -31,6 +32,11 @@ def __init__( self.device_type = device_type self.device_id = device_id + def cblas(self) -> CBlas: + """Return C BLAS function table.""" + err = f"{type(self).__name__} does not provide C BLAS functions" + raise NotImplementedError(err) + def to_numpy(self, data, *, byte_order=None): # pragma: no cover if isinstance(data, numpy.ndarray): if byte_order: @@ -223,56 +229,56 @@ def affine(self, X: Floats2d, W: Floats2d, b: Floats1d) -> Floats2d: Y += b return Y - @overload + @overload def flatten( self, X: List[Floats2d], dtype: Optional[DTypes] = None, pad: int = 0, ndim_if_empty: int = 2, - ) -> Floats2d: + ) -> Floats2d: ... - @overload + @overload def flatten( self, X: List[Ints1d], dtype: Optional[DTypes] = None, pad: int = 0, ndim_if_empty: int = 2, - ) -> Ints1d: + ) -> Ints1d: ... - @overload + @overload def flatten( self, X: List2d, dtype: Optional[DTypes] = None, pad: int = 0, ndim_if_empty: int = 2, - ) -> Array2d: + ) -> Array2d: ... # further specific typed signatures can be added as necessary - @overload + @overload def flatten( self, X: ListXd, dtype: Optional[DTypes] = None, pad: int = 0, ndim_if_empty: int = 2, - ) -> ArrayXd: + ) -> ArrayXd: ... - @overload + @overload def flatten( self, X: Sequence[ArrayXd], dtype: Optional[DTypes] = None, pad: int = 0, ndim_if_empty: int = 2, - ) -> ArrayXd: + ) -> ArrayXd: ... def flatten( @@ -367,7 +373,7 @@ def pad( # noqa: F811 # array sizes. 
length = (length + (round_to - 1)) // round_to * round_to final_shape = (len(seqs), length) + seqs[0].shape[1:] - output: Array3d = self.alloc(final_shape, dtype=seqs[0].dtype) + output: Array3d = cast(Array3d, self.alloc(final_shape, dtype=seqs[0].dtype)) for i, arr in enumerate(seqs): # It's difficult to convince this that the dtypes will match. output[i, : arr.shape[0]] = arr # type: ignore[assignment, call-overload] @@ -445,7 +451,7 @@ def get_dropout_mask(self, shape: Shape, drop: Optional[float]) -> FloatsXd: if drop is None or drop <= 0: return self.xp.ones(shape, dtype="f") elif drop >= 1.0: - return self.alloc(shape) + return self.alloc_f(shape) coinflips = self.xp.random.uniform(0.0, 1.0, shape) mask = (coinflips >= drop) / (1.0 - drop) return cast(FloatsXd, self.asarray(mask, dtype="float32")) @@ -457,7 +463,7 @@ def alloc1f( dtype: Optional[DTypesFloat] = "float32", zeros: bool = True, ) -> Floats1d: - return self.alloc((d0,), dtype=dtype, zeros=zeros) + return cast(Floats1d, self.alloc((d0,), dtype=dtype, zeros=zeros)) def alloc2f( self, @@ -467,7 +473,7 @@ def alloc2f( dtype: Optional[DTypesFloat] = "float32", zeros: bool = True, ) -> Floats2d: - return self.alloc((d0, d1), dtype=dtype, zeros=zeros) + return cast(Floats2d, self.alloc((d0, d1), dtype=dtype, zeros=zeros)) def alloc3f( self, @@ -478,7 +484,7 @@ def alloc3f( dtype: Optional[DTypesFloat] = "float32", zeros: bool = True, ) -> Floats3d: - return self.alloc((d0, d1, d2), dtype=dtype, zeros=zeros) + return cast(Floats3d, self.alloc((d0, d1, d2), dtype=dtype, zeros=zeros)) def alloc4f( self, @@ -490,7 +496,7 @@ def alloc4f( dtype: Optional[DTypesFloat] = "float32", zeros: bool = True, ) -> Floats4d: - return self.alloc((d0, d1, d2, d3), dtype=dtype, zeros=zeros) + return cast(Floats4d, self.alloc((d0, d1, d2, d3), dtype=dtype, zeros=zeros)) def alloc_f( self, @@ -499,7 +505,7 @@ def alloc_f( dtype: Optional[DTypesFloat] = "float32", zeros: bool = True, ) -> FloatsXd: - return self.alloc(shape, dtype=dtype, zeros=zeros) + return cast(FloatsXd, self.alloc(shape, dtype=dtype, zeros=zeros)) def alloc1i( self, @@ -508,7 +514,7 @@ def alloc1i( dtype: Optional[DTypesInt] = "int32", zeros: bool = True, ) -> Ints1d: - return self.alloc((d0,), dtype=dtype, zeros=zeros) + return cast(Ints1d, self.alloc((d0,), dtype=dtype, zeros=zeros)) def alloc2i( self, @@ -518,7 +524,7 @@ def alloc2i( dtype: Optional[DTypesInt] = "int32", zeros: bool = True, ) -> Ints2d: - return self.alloc((d0, d1), dtype=dtype, zeros=zeros) + return cast(Ints2d, self.alloc((d0, d1), dtype=dtype, zeros=zeros)) def alloc3i( self, @@ -529,7 +535,7 @@ def alloc3i( dtype: Optional[DTypesInt] = "int32", zeros: bool = True, ) -> Ints3d: - return self.alloc((d0, d1, d2), dtype=dtype, zeros=zeros) + return cast(Ints3d, self.alloc((d0, d1, d2), dtype=dtype, zeros=zeros)) def alloc4i( self, @@ -541,7 +547,7 @@ def alloc4i( dtype: Optional[DTypesInt] = "int32", zeros: bool = True, ) -> Ints4d: - return self.alloc((d0, d1, d2, d3), dtype=dtype, zeros=zeros) + return cast(Ints4d, self.alloc((d0, d1, d2, d3), dtype=dtype, zeros=zeros)) def alloc_i( self, @@ -550,7 +556,7 @@ def alloc_i( dtype: Optional[DTypesInt] = "int32", zeros: bool = True, ) -> IntsXd: - return self.alloc(shape, dtype=dtype, zeros=zeros) + return cast(IntsXd, self.alloc(shape, dtype=dtype, zeros=zeros)) def alloc( self, @@ -558,7 +564,7 @@ def alloc( *, dtype: Optional[DTypes] = "float32", zeros: bool = True, - ) -> ArrayT: + ) -> ArrayXd: """Allocate an array of a certain shape.""" if isinstance(shape, 
int): shape = (shape,) @@ -620,7 +626,7 @@ def reshape(self, array: ArrayT, shape: Shape) -> ArrayT: def asarray4f( self, - data: Union[Floats4d, Sequence[int]], + data: Union[Floats4d, Sequence[float]], *, dtype: Optional[DTypes] = "float32", ) -> Floats4d: @@ -628,7 +634,7 @@ def asarray4f( def asarray3f( self, - data: Union[Floats3d, Sequence[int]], + data: Union[Floats3d, Sequence[float]], *, dtype: Optional[DTypes] = "float32", ) -> Floats3d: @@ -636,7 +642,7 @@ def asarray3f( def asarray2f( self, - data: Union[Floats2d, Sequence[int]], + data: Union[Floats2d, Sequence[float]], *, dtype: Optional[DTypes] = "float32", ) -> Floats2d: @@ -644,7 +650,7 @@ def asarray2f( def asarray1f( self, - data: Union[Floats1d, Sequence[int]], + data: Union[Floats1d, Sequence[float]], *, dtype: Optional[DTypes] = "float32", ) -> Floats1d: @@ -715,29 +721,29 @@ def as_contig(self, data: ArrayT, dtype: Optional[DTypes] = None) -> ArrayT: kwargs = {"dtype": dtype} if dtype is not None else {} return self.xp.ascontiguousarray(data, **kwargs) - def sigmoid(self, X: FloatsType, *, inplace: bool = False) -> FloatsType: - # To prevent overflows and help with regularization/numerical stability - X = self.xp.clip(X, -20.0, 20.0) - + def sigmoid(self, X: FloatsXdT, *, inplace: bool = False) -> FloatsXdT: if inplace: + # To prevent overflows and help with regularization/numerical stability + X = self.xp.clip(X, -20.0, 20.0, out=X) self.xp.exp(-X, out=X) - X += 1.0 # type: ignore[assignment] - X **= -1.0 # type: ignore[assignment] - return cast(FloatsType, X) + X += 1.0 + X **= -1.0 + return X else: - return cast(FloatsType, 1.0 / (1.0 + self.xp.exp(-X))) + X = self.xp.clip(X, -20.0, 20.0) + return 1.0 / (1.0 + self.xp.exp(-X)) def backprop_sigmoid( - self, dY: FloatsType, Y: FloatsType, *, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, Y: FloatsXdT, *, inplace: bool = False + ) -> FloatsXdT: if inplace: self.dsigmoid(Y, inplace=True) - Y *= dY # type: ignore + Y *= dY return Y else: - return dY * self.dsigmoid(Y, inplace=inplace) # type: ignore + return dY * self.dsigmoid(Y, inplace=inplace) - def dsigmoid(self, Y: FloatsType, *, inplace: bool = False) -> FloatsType: + def dsigmoid(self, Y: FloatsXdT, *, inplace: bool = False) -> FloatsXdT: if inplace: Y *= 1 - Y return Y @@ -858,30 +864,30 @@ def backprop_relu( def clipped_linear( self, - X: FloatsType, + X: FloatsXdT, slope: float = 1.0, offset: float = 0.0, min_val: float = 0.0, max_val: float = 1.0, inplace: bool = False, - ) -> FloatsType: + ) -> FloatsXdT: if inplace: - X *= slope # type: ignore[assignment] - X += offset # type: ignore[assignment] - return cast(FloatsType, self.xp.clip(X, min_val, max_val, out=X)) - out = X * slope + offset # type: ignore[assignment] - return cast(FloatsType, self.xp.clip(out, min_val, max_val)) + X *= slope + X += offset + return self.xp.clip(X, min_val, max_val, out=X) + out = X * slope + offset + return self.xp.clip(out, min_val, max_val) def backprop_clipped_linear( self, - dY: FloatsType, - X: FloatsType, + dY: FloatsXdT, + X: FloatsXdT, slope: float = 1.0, offset: float = 0.0, min_val: float = 0.0, max_val: float = 1.0, inplace: bool = False, - ) -> FloatsType: + ) -> FloatsXdT: low = (min_val - offset) / slope high = (max_val - offset) / slope slope = self.xp.float64(slope).astype(X.dtype) @@ -892,60 +898,58 @@ def backprop_clipped_linear( return dY return dY * dX - def relu_k( - self, X: FloatsType, n: float = 6.0, inplace: bool = False - ) -> FloatsType: + def relu_k(self, X: FloatsXdT, n: float = 
6.0, inplace: bool = False) -> FloatsXdT: return self.clipped_linear(X, max_val=n, inplace=inplace) def backprop_relu_k( - self, dY: FloatsType, X: FloatsType, n: float = 6.0, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, X: FloatsXdT, n: float = 6.0, inplace: bool = False + ) -> FloatsXdT: return self.backprop_clipped_linear(dY, X, max_val=n, inplace=inplace) - def hard_sigmoid(self, X: FloatsType, inplace: bool = False) -> FloatsType: - return self.clipped_linear(X, slope=0.2, offset=0.5) + def hard_sigmoid(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: + return self.clipped_linear(X, slope=0.2, offset=0.5, inplace=inplace) def backprop_hard_sigmoid( - self, dY: FloatsType, X: FloatsType, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: return self.backprop_clipped_linear(dY, X, slope=0.2, offset=0.5) - def hard_tanh(self, X: FloatsType, inplace: bool = False) -> FloatsType: - return self.clipped_linear(X, min_val=-1.0, max_val=1.0) + def hard_tanh(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: + return self.clipped_linear(X, min_val=-1.0, max_val=1.0, inplace=inplace) def backprop_hard_tanh( - self, dY: FloatsType, X: FloatsType, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: return self.backprop_clipped_linear(dY, X, min_val=-1.0, max_val=1.0) - def swish(self, X: FloatsType, inplace: bool = False) -> FloatsType: + def swish(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: if inplace: - X *= self.sigmoid(X) # type: ignore[operator, assignment] - return cast(FloatsType, X) - out = X * self.sigmoid(X) # type: ignore[operator] - return cast(FloatsType, out) + X *= self.sigmoid(X) + return X + out = X * self.sigmoid(X) + return out def backprop_swish( - self, dY: FloatsType, X: FloatsType, Y: FloatsType, inplace: bool = False - ) -> FloatsType: - Y = Y + self.sigmoid(X) * (1 - Y) # type: ignore[operator] + self, dY: FloatsXdT, X: FloatsXdT, Y: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: + Y = Y + self.sigmoid(X) * (1 - Y) if inplace: - dY *= Y # type: ignore[operator, assignment] - return cast(FloatsType, dY) - out = dY * Y # type: ignore[operator] - return cast(FloatsType, out) + dY *= Y + return dY + out = dY * Y + return out # Following https://www.scitepress.org/Papers/2019/74696/74696.pdf - def hard_swish(self, X: FloatsType, inplace: bool = False) -> FloatsType: + def hard_swish(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: if inplace: - X *= self.hard_sigmoid(X) # type: ignore[operator, assignment] - return cast(FloatsType, X) - out = X * self.hard_sigmoid(X) # type: ignore[operator] - return cast(FloatsType, out) + X *= self.hard_sigmoid(X) + return X + out = X * self.hard_sigmoid(X) + return out def backprop_hard_swish( - self, dY: FloatsType, X: FloatsType, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: dX = X * 0.4 + 0.5 dX[X > 2.5] = 1.0 dX[X < -2.5] = 0 @@ -955,15 +959,15 @@ def backprop_hard_swish( return dY * dX # From https://arxiv.org/pdf/1905.02244v5.pdf - def hard_swish_mobilenet(self, X: FloatsType, inplace: bool = False) -> FloatsType: + def hard_swish_mobilenet(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: if inplace: X *= self.relu_k(X + 3) / 6 return X return X * (self.relu_k(X + 3) / 6) def backprop_hard_swish_mobilenet( - self, dY: FloatsType, X: FloatsType, inplace: bool = False - ) 
-> FloatsType: + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: dX = (1 / 6) * (X * 2.0 + 3.0) dX[X > 3.0] = 1.0 dX[X < -3.0] = 0 @@ -972,9 +976,38 @@ def backprop_hard_swish_mobilenet( return dY return dX * dY + def dish(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: + tmp = self.xp.square(X) + tmp += 1.0 + self.xp.sqrt(tmp, out=tmp) + tmp = X / tmp + tmp += 1 + tmp *= 0.5 + if inplace: + X *= tmp + return X + else: + return X * tmp + + def backprop_dish( + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: + x_sq = self.xp.square(X) + x_sq_plus_one = x_sq + 1.0 + deriv = X / self.xp.sqrt(x_sq_plus_one) + second = 0.5 * X * x_sq + second /= x_sq_plus_one**1.5 + deriv -= second + deriv += 0.5 + if inplace: + dY *= deriv + return dY + else: + return dY * deriv + # Code snippet taken from: # https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/ - def erf(self, X: FloatsType) -> FloatsType: + def erf(self, X: FloatsXdT) -> FloatsXdT: # save the sign of x sign = self.xp.sign(X) X = self.xp.abs(X) @@ -994,10 +1027,12 @@ def erf(self, X: FloatsType) -> FloatsType: out = out.astype(X.dtype) return out - def sechsq(self, X: FloatsType) -> FloatsType: + def sechsq(self, X: FloatsXdT) -> FloatsXdT: + # Avoid overflow in cosh. Clipping at |20| has an error of 1.7e-17. + X = self.xp.clip(X, -20.0, 20.0) return (1 / self.xp.cosh(X)) ** 2 - def gelu_approx(self, X: FloatsType, inplace: bool = False) -> FloatsType: + def gelu_approx(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: tmp = 1.0 + self.xp.tanh(SQRT2PI * (X + 0.044715 * self.xp.power(X, 3))) tmp *= 0.5 tmp = tmp.astype(X.dtype) @@ -1010,9 +1045,9 @@ def gelu_approx(self, X: FloatsType, inplace: bool = False) -> FloatsType: return Y def backprop_gelu_approx( - self, dY: FloatsType, X: FloatsType, inplace: bool = False - ) -> FloatsType: - dX = self.alloc_f(X.shape) + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: + dX = cast(FloatsXdT, self.alloc_f(X.shape)) Xp3 = self.xp.power(X, 3) tmp = 0.5 * self.xp.tanh(0.0356774 * Xp3 + 0.797885 * X) tmp += (0.0535161 * Xp3 + 0.398942 * X) * self.sechsq( @@ -1025,27 +1060,27 @@ def backprop_gelu_approx( return dY return dY * dX - def gelu(self, X: FloatsType, inplace: bool = False) -> FloatsType: + def gelu(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: # GELU(x) = x · Φ(x) cdf = gaussian_cdf(self, X) if inplace: - X *= cdf # type: ignore[operator, assignment] + X *= cdf return X - return X * cdf # type: ignore[operator, return-value] + return X * cdf def backprop_gelu( - self, dY: FloatsType, X: FloatsType, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: # GELU'(x) = Φ(x) + x · PDF(x) - dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X) # type: ignore[operator] + dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X) if inplace: dY *= dX return dY return dY * dX def mish( - self, X: FloatsType, threshold: float = 20.0, inplace: bool = False - ) -> FloatsType: + self, X: FloatsXdT, threshold: float = 20.0, inplace: bool = False + ) -> FloatsXdT: tmp = X * self.xp.tanh(self.xp.log(1.0 + self.xp.exp(X))) Y = self.xp.where(X >= threshold, X, tmp) if inplace: @@ -1056,11 +1091,11 @@ def mish( def backprop_mish( self, - dY: FloatsType, + dY: FloatsXdT, X: Floats2d, threshold: float = 20.0, inplace: bool = False, - ) -> FloatsType: + ) -> FloatsXdT: if dY.shape != X.shape: msg = f"arrays have incompatible shapes: 
{dY.shape} and {X.shape}" raise ValueError(msg) @@ -1106,6 +1141,10 @@ def adam( learn_rate: float, mod_rate: float = 1.0, ) -> Tuple[Floats1d, Floats1d, Floats1d, Floats1d]: + _check_compatible_shape(weights, gradient) + _check_compatible_shape(weights, mom1) + _check_compatible_shape(weights, mom2) + # Internals for optimizer mom1 *= beta1 mom2 *= beta2 @@ -1145,6 +1184,29 @@ def reduce_sum(self, X: Floats2d, lengths: Ints1d) -> Floats2d: Y[i] = 0.0 return Y + def reduce_first(self, X: Floats2d, lengths: Ints1d) -> Tuple[Floats2d, Ints1d]: + if lengths.size == 0: + return self.alloc2f(0, X.shape[1]), lengths + if not self.xp.all(lengths > 0): + raise ValueError(f"all sequence lengths must be >= 0") + starts_ends = self.alloc1i(lengths.shape[0] + 1, zeros=False) + starts_ends[0] = 0 + starts_ends[1:] = lengths.cumsum() + if starts_ends[-1] != X.shape[0]: + raise IndexError("lengths must sum up to the number of rows") + + return X[starts_ends[:-1]], starts_ends + + def reduce_last(self, X: Floats2d, lengths: Ints1d) -> Tuple[Floats2d, Ints1d]: + if lengths.size == 0: + return self.alloc2f(0, X.shape[1]), lengths + if not self.xp.all(lengths > 0): + raise ValueError(f"all sequence lengths must be >= 0") + lasts = lengths.cumsum() - 1 + if lasts[-1] + 1 != X.shape[0]: + raise IndexError("lengths must sum up to the number of rows") + return X[lasts], lasts + def reduce_mean(self, X: Floats2d, lengths: Ints1d) -> Floats2d: Y = self.alloc2f(lengths.shape[0], X.shape[1], zeros=False) start = 0 @@ -1175,6 +1237,26 @@ def reduce_max(self, X: Floats2d, lengths: Ints1d) -> Tuple[Floats2d, Ints2d]: start += length return Y, which + def backprop_reduce_first( + self, d_firsts: Floats2d, starts_ends: Ints1d + ) -> Floats2d: + if starts_ends.size < 2: + raise ValueError(f"starts_ends should least have size 2") + dX = self.alloc2f( + int(starts_ends[-1]), d_firsts.shape[1], dtype=d_firsts.dtype, zeros=True + ) + dX[starts_ends[:-1]] = d_firsts + return dX + + def backprop_reduce_last(self, d_lasts: Floats2d, lasts: Ints1d) -> Floats2d: + if lasts.size < 1: + raise ValueError(f"lasts should least have size 2") + dX = self.alloc2f( + int(lasts[-1]) + 1, d_lasts.shape[1], dtype=d_lasts.dtype, zeros=True + ) + dX[lasts] = d_lasts + return dX + def backprop_reduce_sum(self, d_sums: Floats2d, lengths: Ints1d) -> Floats2d: dX = self.alloc2f( lengths.sum(), d_sums.shape[1], dtype=d_sums.dtype, zeros=False @@ -1242,6 +1324,9 @@ def position_encode( numpy_ops = NumpyOps() return self.asarray2f(numpy_ops.position_encode(N, D, period, out)) + def gather_add(self, table: Floats2d, indices: Ints2d) -> Floats2d: + return table[indices].sum(axis=1) # type: ignore[return-value] + def scatter_add( self, table: FloatsXd, indices: IntsXd, values: FloatsXd ) -> FloatsXd: @@ -1556,11 +1641,17 @@ def dtanh(Y: ArrayT) -> ArrayT: return 1 - Y**2 -def gaussian_cdf(ops: Ops, X: FloatsType) -> FloatsType: +def gaussian_cdf(ops: Ops, X: FloatsXdT) -> FloatsXdT: """Gaussian CDF for distribution with mean 0 and stdev 1.""" return 0.5 * (1.0 + ops.erf(INV_SQRT2 * X)) -def gaussian_pdf(ops: Ops, X: FloatsType) -> FloatsType: +def gaussian_pdf(ops: Ops, X: FloatsXdT) -> FloatsXdT: """Gaussian PDF for distribution with mean 0 and stdev 1.""" return INV_SQRT_2PI * ops.xp.exp(-0.5 * X * X) + + +def _check_compatible_shape(u: FloatsXd, v: FloatsXd): + if u.shape != v.shape: + msg = f"arrays have incompatible shapes: {u.shape} and {v.shape}" + raise ValueError(msg) diff --git a/thinc/compat.py b/thinc/compat.py new file mode 100644 index 
000000000..2d8b40345 --- /dev/null +++ b/thinc/compat.py @@ -0,0 +1,90 @@ +from packaging.version import Version + +try: # pragma: no cover + import cupy + import cupyx + + has_cupy = True + cupy_version = Version(cupy.__version__) + try: + cupy.cuda.runtime.getDeviceCount() + has_cupy_gpu = True + except cupy.cuda.runtime.CUDARuntimeError: + has_cupy_gpu = False + + if cupy_version.major >= 10: + # fromDlpack was deprecated in v10.0.0. + cupy_from_dlpack = cupy.from_dlpack + else: + cupy_from_dlpack = cupy.fromDlpack +except (ImportError, AttributeError): + cupy = None + cupyx = None + cupy_version = Version("0.0.0") + has_cupy = False + cupy_from_dlpack = None + has_cupy_gpu = False + + +try: # pragma: no cover + import torch.utils.dlpack + import torch + + has_torch = True + has_torch_cuda_gpu = torch.cuda.device_count() != 0 + has_torch_mps_gpu = ( + hasattr(torch, "has_mps") + and torch.has_mps # type: ignore[attr-defined] + and torch.backends.mps.is_available() # type: ignore[attr-defined] + ) + has_torch_gpu = has_torch_cuda_gpu + torch_version = Version(str(torch.__version__)) + has_torch_amp = ( + torch_version >= Version("1.9.0") + and not torch.cuda.amp.common.amp_definitely_not_available() + ) +except ImportError: # pragma: no cover + torch = None # type: ignore + has_torch = False + has_torch_cuda_gpu = False + has_torch_gpu = False + has_torch_mps_gpu = False + has_torch_amp = False + torch_version = Version("0.0.0") + +try: # pragma: no cover + import tensorflow.experimental.dlpack + import tensorflow + + has_tensorflow = True + has_tensorflow_gpu = len(tensorflow.config.get_visible_devices("GPU")) > 0 +except ImportError: # pragma: no cover + tensorflow = None + has_tensorflow = False + has_tensorflow_gpu = False + + +try: # pragma: no cover + import mxnet + + has_mxnet = True +except ImportError: # pragma: no cover + mxnet = None + has_mxnet = False + +try: + import h5py +except ImportError: # pragma: no cover + h5py = None + + +try: # pragma: no cover + import os_signpost + + has_os_signpost = True +except ImportError: + os_signpost = None + has_os_signpost = False + + +has_gpu = has_cupy_gpu or has_torch_mps_gpu diff --git a/thinc/config.py b/thinc/config.py index 837f91b76..8c0e752c5 100644 --- a/thinc/config.py +++ b/thinc/config.py @@ -1,701 +1,10 @@ -from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping -from typing import Iterable, Sequence, cast -from types import GeneratorType -from dataclasses import dataclass -from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH -from configparser import InterpolationMissingOptionError, InterpolationSyntaxError -from configparser import NoSectionError, NoOptionError, InterpolationDepthError -from configparser import ParsingError -from pathlib import Path -from pydantic import BaseModel, create_model, ValidationError, Extra -from pydantic.main import ModelMetaclass -from pydantic.fields import ModelField -from wasabi import table -import srsly import catalogue -import inspect -import io -import numpy -import copy -import re - +import confection +from confection import Config, ConfigValidationError, Promise, VARIABLE_RE from .types import Decorator -# Field used for positional arguments, e.g. [section.*.xyz]. The alias is -# required for the schema (shouldn't clash with user-defined arg names) -ARGS_FIELD = "*" -ARGS_FIELD_ALIAS = "VARIABLE_POSITIONAL_ARGS" -# Aliases for fields that would otherwise shadow pydantic attributes. 
Can be any -# string, so we're using name + space so it looks the same in error messages etc. -RESERVED_FIELDS = {"validate": "validate\u0020"} -# Internal prefix used to mark section references for custom interpolation -SECTION_PREFIX = "__SECTION__:" -# Values that shouldn't be loaded during interpolation because it'd cause -# even explicit string values to be incorrectly parsed as bools/None etc. -JSON_EXCEPTIONS = ("true", "false", "null") -# Regex to detect whether a value contains a variable -VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}") - - -class CustomInterpolation(ExtendedInterpolation): - def before_read(self, parser, section, option, value): - # If we're dealing with a quoted string as the interpolation value, - # make sure we load and unquote it so we don't end up with '"value"' - try: - json_value = srsly.json_loads(value) - if isinstance(json_value, str) and json_value not in JSON_EXCEPTIONS: - value = json_value - except Exception: - pass - return super().before_read(parser, section, option, value) - - def before_get(self, parser, section, option, value, defaults): - # Mostly copy-pasted from the built-in configparser implementation. - L = [] - self.interpolate(parser, option, L, value, section, defaults, 1) - return "".join(L) - - def interpolate(self, parser, option, accum, rest, section, map, depth): - # Mostly copy-pasted from the built-in configparser implementation. - # We need to overwrite this method so we can add special handling for - # block references :( All values produced here should be strings – - # we need to wait until the whole config is interpreted anyways so - # filling in incomplete values here is pointless. All we need is the - # section reference so we can fetch it later. - rawval = parser.get(section, option, raw=True, fallback=rest) - if depth > MAX_INTERPOLATION_DEPTH: - raise InterpolationDepthError(option, section, rawval) - while rest: - p = rest.find("$") - if p < 0: - accum.append(rest) - return - if p > 0: - accum.append(rest[:p]) - rest = rest[p:] - # p is no longer used - c = rest[1:2] - if c == "$": - accum.append("$") - rest = rest[2:] - elif c == "{": - # We want to treat both ${a:b} and ${a.b} the same - m = self._KEYCRE.match(rest) - if m is None: - err = f"bad interpolation variable reference {rest}" - raise InterpolationSyntaxError(option, section, err) - orig_var = m.group(1) - path = orig_var.replace(":", ".").rsplit(".", 1) - rest = rest[m.end() :] - sect = section - opt = option - try: - if len(path) == 1: - opt = parser.optionxform(path[0]) - if opt in map: - v = map[opt] - else: - # We have block reference, store it as a special key - section_name = parser[parser.optionxform(path[0])]._name - v = self._get_section_name(section_name) - elif len(path) == 2: - sect = path[0] - opt = parser.optionxform(path[1]) - fallback = "__FALLBACK__" - v = parser.get(sect, opt, raw=True, fallback=fallback) - # If a variable doesn't exist, try again and treat the - # reference as a section - if v == fallback: - v = self._get_section_name(parser[f"{sect}.{opt}"]._name) - else: - err = f"More than one ':' found: {rest}" - raise InterpolationSyntaxError(option, section, err) - except (KeyError, NoSectionError, NoOptionError): - raise InterpolationMissingOptionError( - option, section, rawval, orig_var - ) from None - if "$" in v: - new_map = dict(parser.items(sect, raw=True)) - self.interpolate(parser, opt, accum, v, sect, new_map, depth + 1) - else: - accum.append(v) - else: - err = "'$' must be followed by '$' or '{', " "found: %r" % (rest,) - 
raise InterpolationSyntaxError(option, section, err) - - def _get_section_name(self, name: str) -> str: - """Generate the name of a section. Note that we use a quoted string here - so we can use section references within lists and load the list as - JSON. Since section references can't be used within strings, we don't - need the quoted vs. unquoted distinction like we do for variables. - - Examples (assuming section = {"foo": 1}): - - value: ${section.foo} -> value: 1 - - value: "hello ${section.foo}" -> value: "hello 1" - - value: ${section} -> value: {"foo": 1} - - value: "${section}" -> value: {"foo": 1} - - value: "hello ${section}" -> invalid - """ - return f'"{SECTION_PREFIX}{name}"' - - -def get_configparser(interpolate: bool = True): - config = ConfigParser(interpolation=CustomInterpolation() if interpolate else None) - # Preserve case of keys: https://stackoverflow.com/a/1611877/6400719 - config.optionxform = str # type: ignore - return config - - -class Config(dict): - """This class holds the model and training configuration and can load and - save the TOML-style configuration format from/to a string, file or bytes. - The Config class is a subclass of dict and uses Python's ConfigParser - under the hood. - """ - - is_interpolated: bool - - def __init__( - self, - data: Optional[Union[Dict[str, Any], "ConfigParser", "Config"]] = None, - *, - is_interpolated: Optional[bool] = None, - section_order: Optional[List[str]] = None, - ) -> None: - """Initialize a new Config object with optional data.""" - dict.__init__(self) - if data is None: - data = {} - if not isinstance(data, (dict, Config, ConfigParser)): - raise ValueError( - f"Can't initialize Config with data. Expected dict, Config or " - f"ConfigParser but got: {type(data)}" - ) - # Whether the config has been interpolated. We can use this to check - # whether we need to interpolate again when it's resolved. We assume - # that a config is interpolated by default. - if is_interpolated is not None: - self.is_interpolated = is_interpolated - elif isinstance(data, Config): - self.is_interpolated = data.is_interpolated - else: - self.is_interpolated = True - if section_order is not None: - self.section_order = section_order - elif isinstance(data, Config): - self.section_order = data.section_order - else: - self.section_order = [] - # Update with data - self.update(self._sort(data)) - - def interpolate(self) -> "Config": - """Interpolate a config. Returns a copy of the object.""" - # This is currently the most effective way because we need our custom - # to_str logic to run in order to re-serialize the values so we can - # interpolate them again. ConfigParser.read_dict will just call str() - # on all values, which isn't enough. - return Config().from_str(self.to_str()) - - def interpret_config(self, config: "ConfigParser") -> None: - """Interpret a config, parse nested sections and parse the values - as JSON. Mostly used internally and modifies the config in place. - """ - self._validate_sections(config) - # Sort sections by depth, so that we can iterate breadth-first. This - # allows us to check that we're not expanding an undefined block. 
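For reference, a minimal sketch of the variable interpolation behaviour the removed CustomInterpolation class implemented; after this change the same behaviour is assumed to come from confection via the re-exported Config (illustrative only, not part of the patch; section and value names are made up):

from thinc.config import Config

CONFIG_STR = """
[hyper_params]
learn_rate = 0.001

[optimizer]
@optimizers = "Adam.v1"
learn_rate = ${hyper_params.learn_rate}
"""

config = Config().from_str(CONFIG_STR)  # interpolates by default
# config["optimizer"]["learn_rate"] -> 0.001

raw = Config().from_str(CONFIG_STR, interpolate=False)
# raw["optimizer"]["learn_rate"] -> "${hyper_params.learn_rate}" (left as-is)
# raw.interpolate() returns a copy with the variable filled in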
- get_depth = lambda item: len(item[0].split(".")) - for section, values in sorted(config.items(), key=get_depth): - if section == "DEFAULT": - # Skip [DEFAULT] section so it doesn't cause validation error - continue - parts = section.split(".") - node = self - for part in parts[:-1]: - if part == "*": - node = node.setdefault(part, {}) - elif part not in node: - err_title = f"Error parsing config section. Perhaps a section name is wrong?" - err = [{"loc": parts, "msg": f"Section '{part}' is not defined"}] - raise ConfigValidationError( - config=self, errors=err, title=err_title - ) - else: - node = node[part] - if not isinstance(node, dict): - # Happens if both value *and* subsection were defined for a key - err = [{"loc": parts, "msg": "found conflicting values"}] - err_cfg = f"{self}\n{({part: dict(values)})}" - raise ConfigValidationError(config=err_cfg, errors=err) - # Set the default section - node = node.setdefault(parts[-1], {}) - if not isinstance(node, dict): - # Happens if both value *and* subsection were defined for a key - err = [{"loc": parts, "msg": "found conflicting values"}] - err_cfg = f"{self}\n{({part: dict(values)})}" - raise ConfigValidationError(config=err_cfg, errors=err) - try: - keys_values = list(values.items()) - except InterpolationMissingOptionError as e: - raise ConfigValidationError(desc=f"{e}") from None - for key, value in keys_values: - config_v = config.get(section, key) - node[key] = self._interpret_value(config_v) - self.replace_section_refs(self) - - def replace_section_refs( - self, config: Union[Dict[str, Any], "Config"], parent: str = "" - ) -> None: - """Replace references to section blocks in the final config.""" - for key, value in config.items(): - key_parent = f"{parent}.{key}".strip(".") - if isinstance(value, dict): - self.replace_section_refs(value, parent=key_parent) - elif isinstance(value, list): - config[key] = [ - self._get_section_ref(v, parent=[parent, key]) for v in value - ] - else: - config[key] = self._get_section_ref(value, parent=[parent, key]) - - def _interpret_value(self, value: Any) -> Any: - """Interpret a single config value.""" - result = try_load_json(value) - # If value is a string and it contains a variable, use original value - # (not interpreted string, which could lead to double quotes: - # ${x.y} -> "${x.y}" -> "'${x.y}'"). Make sure to check it's a string, - # so we're not keeping lists as strings. - # NOTE: This currently can't handle uninterpolated values like [${x.y}]! - if isinstance(result, str) and VARIABLE_RE.search(value): - result = value - if isinstance(result, list): - return [self._interpret_value(v) for v in result] - return result - - def _get_section_ref(self, value: Any, *, parent: List[str] = []) -> Any: - """Get a single section reference.""" - if isinstance(value, str) and value.startswith(f'"{SECTION_PREFIX}'): - value = try_load_json(value) - if isinstance(value, str) and value.startswith(SECTION_PREFIX): - parts = value.replace(SECTION_PREFIX, "").split(".") - result = self - for item in parts: - try: - result = result[item] - except (KeyError, TypeError): # This should never happen - err_title = "Error parsing reference to config section" - err_msg = f"Section '{'.'.join(parts)}' is not defined" - err = [{"loc": parts, "msg": err_msg}] - raise ConfigValidationError( - config=self, errors=err, title=err_title - ) from None - return result - elif isinstance(value, str) and SECTION_PREFIX in value: - # String value references a section (either a dict or return - # value of promise). 
We can't allow this, since variables are - # always interpolated *before* configs are resolved. - err_desc = ( - "Can't reference whole sections or return values of function " - "blocks inside a string or list\n\nYou can change your variable to " - "reference a value instead. Keep in mind that it's not " - "possible to interpolate the return value of a registered " - "function, since variables are interpolated when the config " - "is loaded, and registered functions are resolved afterwards." - ) - err = [{"loc": parent, "msg": "uses section variable in string or list"}] - raise ConfigValidationError(errors=err, desc=err_desc) - return value - - def copy(self) -> "Config": - """Deepcopy the config.""" - try: - config = copy.deepcopy(self) - except Exception as e: - raise ValueError(f"Couldn't deep-copy config: {e}") from e - return Config( - config, - is_interpolated=self.is_interpolated, - section_order=self.section_order, - ) - - def merge( - self, updates: Union[Dict[str, Any], "Config"], remove_extra: bool = False - ) -> "Config": - """Deep merge the config with updates, using current as defaults.""" - defaults = self.copy() - updates = Config(updates).copy() - merged = deep_merge_configs(updates, defaults, remove_extra=remove_extra) - return Config( - merged, - is_interpolated=defaults.is_interpolated and updates.is_interpolated, - section_order=defaults.section_order, - ) - - def _sort( - self, data: Union["Config", "ConfigParser", Dict[str, Any]] - ) -> Dict[str, Any]: - """Sort sections using the currently defined sort order. Sort - sections by index on section order, if available, then alphabetic, and - account for subsections, which should always follow their parent. - """ - sort_map = {section: i for i, section in enumerate(self.section_order)} - sort_key = lambda x: ( - sort_map.get(x[0].split(".")[0], len(sort_map)), - _mask_positional_args(x[0]), - ) - return dict(sorted(data.items(), key=sort_key)) - - def _set_overrides(self, config: "ConfigParser", overrides: Dict[str, Any]) -> None: - """Set overrides in the ConfigParser before config is interpreted.""" - err_title = "Error parsing config overrides" - for key, value in overrides.items(): - err_msg = "not a section value that can be overwritten" - err = [{"loc": key.split("."), "msg": err_msg}] - if "." not in key: - raise ConfigValidationError(errors=err, title=err_title) - section, option = key.rsplit(".", 1) - # Check for section and accept if option not in config[section] - if section not in config: - raise ConfigValidationError(errors=err, title=err_title) - config.set(section, option, try_dump_json(value, overrides)) - - def _validate_sections(self, config: "ConfigParser") -> None: - # If the config defines top-level properties that are not sections (e.g. - # if config was constructed from dict), those values would be added as - # [DEFAULTS] and included in *every other section*. This is usually not - # what we want and it can lead to very confusing results. 
- default_section = config.defaults() - if default_section: - err_title = "Found config values without a top-level section" - err_msg = "not part of a section" - err = [{"loc": [k], "msg": err_msg} for k in default_section] - raise ConfigValidationError(errors=err, title=err_title) - - def from_str( - self, text: str, *, interpolate: bool = True, overrides: Dict[str, Any] = {} - ) -> "Config": - """Load the config from a string.""" - config = get_configparser(interpolate=interpolate) - if overrides: - config = get_configparser(interpolate=False) - try: - config.read_string(text) - except ParsingError as e: - desc = f"Make sure the sections and values are formatted correctly.\n\n{e}" - raise ConfigValidationError(desc=desc) from None - config._sections = self._sort(config._sections) - self._set_overrides(config, overrides) - self.clear() - self.interpret_config(config) - if overrides and interpolate: - # do the interpolation. Avoids recursion because the new call from_str call will have overrides as empty - self = self.interpolate() - self.is_interpolated = interpolate - return self - - def to_str(self, *, interpolate: bool = True) -> str: - """Write the config to a string.""" - flattened = get_configparser(interpolate=interpolate) - queue: List[Tuple[tuple, "Config"]] = [(tuple(), self)] - for path, node in queue: - section_name = ".".join(path) - is_kwarg = path and path[-1] != "*" - if is_kwarg and not flattened.has_section(section_name): - # Always create sections for non-'*' sections, not only if - # they have leaf entries, as we don't want to expand - # blocks that are undefined - flattened.add_section(section_name) - for key, value in node.items(): - if hasattr(value, "items"): - # Reference to a function with no arguments, serialize - # inline as a dict and don't create new section - if registry.is_promise(value) and len(value) == 1 and is_kwarg: - flattened.set(section_name, key, try_dump_json(value, node)) - else: - queue.append((path + (key,), value)) - else: - flattened.set(section_name, key, try_dump_json(value, node)) - # Order so subsection follow parent (not all sections, then all subs etc.) 
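A hedged sketch of the override handling implemented by _set_overrides/from_str above (kept for reference, not part of the patch; the dotted key and values are illustrative):

from thinc.config import Config

CONFIG_STR = """
[optimizer]
@optimizers = "Adam.v1"
learn_rate = 0.001
"""

# Overrides use "section.option" keys and must refer to an existing section,
# otherwise a ConfigValidationError is raised.
config = Config().from_str(CONFIG_STR, overrides={"optimizer.learn_rate": 0.01})
# config["optimizer"]["learn_rate"] -> 0.01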
- flattened._sections = self._sort(flattened._sections) - self._validate_sections(flattened) - string_io = io.StringIO() - flattened.write(string_io) - return string_io.getvalue().strip() - - def to_bytes(self, *, interpolate: bool = True) -> bytes: - """Serialize the config to a byte string.""" - return self.to_str(interpolate=interpolate).encode("utf8") - - def from_bytes( - self, - bytes_data: bytes, - *, - interpolate: bool = True, - overrides: Dict[str, Any] = {}, - ) -> "Config": - """Load the config from a byte string.""" - return self.from_str( - bytes_data.decode("utf8"), interpolate=interpolate, overrides=overrides - ) - - def to_disk(self, path: Union[str, Path], *, interpolate: bool = True): - """Serialize the config to a file.""" - path = Path(path) if isinstance(path, str) else path - with path.open("w", encoding="utf8") as file_: - file_.write(self.to_str(interpolate=interpolate)) - - def from_disk( - self, - path: Union[str, Path], - *, - interpolate: bool = True, - overrides: Dict[str, Any] = {}, - ) -> "Config": - """Load config from a file.""" - path = Path(path) if isinstance(path, str) else path - with path.open("r", encoding="utf8") as file_: - text = file_.read() - return self.from_str(text, interpolate=interpolate, overrides=overrides) - - -def _mask_positional_args(name: str) -> List[Optional[str]]: - """Create a section name representation that masks names - of positional arguments to retain their order in sorts.""" - - stable_name = cast(List[Optional[str]], name.split(".")) - - # Remove names of sections that are a positional argument. - for i in range(1, len(stable_name)): - if stable_name[i - 1] == "*": - stable_name[i] = None - - return stable_name - - -def try_load_json(value: str) -> Any: - """Load a JSON string if possible, otherwise default to original value.""" - try: - return srsly.json_loads(value) - except Exception: - return value - - -def try_dump_json(value: Any, data: Union[Dict[str, dict], Config, str] = "") -> str: - """Dump a config value as JSON and output user-friendly error if it fails.""" - # Special case if we have a variable: it's already a string so don't dump - # to preserve ${x:y} vs. "${x:y}" - if isinstance(value, str) and VARIABLE_RE.search(value): - return value - if isinstance(value, str) and value.replace(".", "", 1).isdigit(): - # Work around values that are strings but numbers - value = f'"{value}"' - try: - return srsly.json_dumps(value) - except Exception as e: - err_msg = ( - f"Couldn't serialize config value of type {type(value)}: {e}. Make " - f"sure all values in your config are JSON-serializable. If you want " - f"to include Python objects, use a registered function that returns " - f"the object instead." 
- ) - raise ConfigValidationError(config=data, desc=err_msg) from e - - -def deep_merge_configs( - config: Union[Dict[str, Any], Config], - defaults: Union[Dict[str, Any], Config], - *, - remove_extra: bool = False, -) -> Union[Dict[str, Any], Config]: - """Deep merge two configs.""" - if remove_extra: - # Filter out values in the original config that are not in defaults - keys = list(config.keys()) - for key in keys: - if key not in defaults: - del config[key] - for key, value in defaults.items(): - if isinstance(value, dict): - node = config.setdefault(key, {}) - if not isinstance(node, dict): - continue - value_promises = [k for k in value if k.startswith("@")] - value_promise = value_promises[0] if value_promises else None - node_promises = [k for k in node if k.startswith("@")] if node else [] - node_promise = node_promises[0] if node_promises else None - # We only update the block from defaults if it refers to the same - # registered function - if ( - value_promise - and node_promise - and ( - value_promise in node - and node[value_promise] != value[value_promise] - ) - ): - continue - if node_promise and ( - node_promise not in value or node[node_promise] != value[node_promise] - ): - continue - defaults = deep_merge_configs(node, value, remove_extra=remove_extra) - elif key not in config: - config[key] = value - return config - - -class ConfigValidationError(ValueError): - def __init__( - self, - *, - config: Optional[Union[Config, Dict[str, Dict[str, Any]], str]] = None, - errors: Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]] = tuple(), - title: Optional[str] = "Config validation error", - desc: Optional[str] = None, - parent: Optional[str] = None, - show_config: bool = True, - ) -> None: - """Custom error for validating configs. - - config (Union[Config, Dict[str, Dict[str, Any]], str]): The - config the validation error refers to. - errors (Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]]): - A list of errors as dicts with keys "loc" (list of strings - describing the path of the value), "msg" (validation message - to show) and optional "type" (mostly internals). - Same format as produced by pydantic's validation error (e.errors()). - title (str): The error title. - desc (str): Optional error description, displayed below the title. - parent (str): Optional parent to use as prefix for all error locations. - For example, parent "element" will result in "element -> a -> b". - show_config (bool): Whether to print the whole config with the error. - - ATTRIBUTES: - config (Union[Config, Dict[str, Dict[str, Any]], str]): The config. - errors (Iterable[Dict[str, Any]]): The errors. - error_types (Set[str]): All "type" values defined in the errors, if - available. This is most relevant for the pydantic errors that define - types like "type_error.integer". This attribute makes it easy to - check if a config validation error includes errors of a certain - type, e.g. to log additional information or custom help messages. - title (str): The title. - desc (str): The description. - parent (str): The parent. - show_config (bool): Whether to show the config. - text (str): The formatted error text. 
- """ - self.config = config - self.errors = errors - self.title = title - self.desc = desc - self.parent = parent - self.show_config = show_config - self.error_types = set() - for error in self.errors: - err_type = error.get("type") - if err_type: - self.error_types.add(err_type) - self.text = self._format() - ValueError.__init__(self, self.text) - - @classmethod - def from_error( - cls, - err: "ConfigValidationError", - title: Optional[str] = None, - desc: Optional[str] = None, - parent: Optional[str] = None, - show_config: Optional[bool] = None, - ) -> "ConfigValidationError": - """Create a new ConfigValidationError based on an existing error, e.g. - to re-raise it with different settings. If no overrides are provided, - the values from the original error are used. - - err (ConfigValidationError): The original error. - title (str): Overwrite error title. - desc (str): Overwrite error description. - parent (str): Overwrite error parent. - show_config (bool): Overwrite whether to show config. - RETURNS (ConfigValidationError): The new error. - """ - return cls( - config=err.config, - errors=err.errors, - title=title if title is not None else err.title, - desc=desc if desc is not None else err.desc, - parent=parent if parent is not None else err.parent, - show_config=show_config if show_config is not None else err.show_config, - ) - - def _format(self) -> str: - """Format the error message.""" - loc_divider = "->" - data = [] - for error in self.errors: - err_loc = f" {loc_divider} ".join([str(p) for p in error.get("loc", [])]) - if self.parent: - err_loc = f"{self.parent} {loc_divider} {err_loc}" - data.append((err_loc, error.get("msg"))) - result = [] - if self.title: - result.append(self.title) - if self.desc: - result.append(self.desc) - if data: - result.append(table(data)) - if self.config and self.show_config: - result.append(f"{self.config}") - return "\n\n" + "\n".join(result) - - -def alias_generator(name: str) -> str: - """Generate field aliases in promise schema.""" - # Underscore fields are not allowed in model, so use alias - if name == ARGS_FIELD_ALIAS: - return ARGS_FIELD - # Auto-alias fields that shadow base model attributes - if name in RESERVED_FIELDS: - return RESERVED_FIELDS[name] - return name - - -def copy_model_field(field: ModelField, type_: Any) -> ModelField: - """Copy a model field and assign a new type, e.g. to accept an Any type - even though the original value is typed differently. 
- """ - return ModelField( - name=field.name, - type_=type_, - class_validators=field.class_validators, - model_config=field.model_config, - default=field.default, - default_factory=field.default_factory, - required=field.required, - ) - - -class EmptySchema(BaseModel): - class Config: - extra = "allow" - arbitrary_types_allowed = True - - -class _PromiseSchemaConfig: - extra = "forbid" - arbitrary_types_allowed = True - alias_generator = alias_generator - - -@dataclass -class Promise: - registry: str - name: str - args: List[str] - kwargs: Dict[str, Any] - - -class registry(object): +class registry(confection.registry): # fmt: off optimizers: Decorator = catalogue.create("thinc", "optimizers", entry_points=True) schedules: Decorator = catalogue.create("thinc", "schedules", entry_points=True) @@ -716,346 +25,5 @@ def create(cls, registry_name: str, entry_points: bool = False) -> None: ) setattr(cls, registry_name, reg) - @classmethod - def has(cls, registry_name: str, func_name: str) -> bool: - """Check whether a function is available in a registry.""" - if not hasattr(cls, registry_name): - return False - reg = getattr(cls, registry_name) - return func_name in reg - - @classmethod - def get(cls, registry_name: str, func_name: str) -> Callable: - """Get a registered function from a given registry.""" - if not hasattr(cls, registry_name): - raise ValueError(f"Unknown registry: '{registry_name}'") - reg = getattr(cls, registry_name) - func = reg.get(func_name) - if func is None: - raise ValueError(f"Could not find '{func_name}' in '{registry_name}'") - return func - - @classmethod - def resolve( - cls, - config: Union[Config, Dict[str, Dict[str, Any]]], - *, - schema: Type[BaseModel] = EmptySchema, - overrides: Dict[str, Any] = {}, - validate: bool = True, - ) -> Dict[str, Any]: - resolved, _ = cls._make( - config, schema=schema, overrides=overrides, validate=validate, resolve=True - ) - return resolved - - @classmethod - def fill( - cls, - config: Union[Config, Dict[str, Dict[str, Any]]], - *, - schema: Type[BaseModel] = EmptySchema, - overrides: Dict[str, Any] = {}, - validate: bool = True, - ): - _, filled = cls._make( - config, schema=schema, overrides=overrides, validate=validate, resolve=False - ) - return filled - - @classmethod - def _make( - cls, - config: Union[Config, Dict[str, Dict[str, Any]]], - *, - schema: Type[BaseModel] = EmptySchema, - overrides: Dict[str, Any] = {}, - resolve: bool = True, - validate: bool = True, - ) -> Tuple[Dict[str, Any], Config]: - """Unpack a config dictionary and create two versions of the config: - a resolved version with objects from the registry created recursively, - and a filled version with all references to registry functions left - intact, but filled with all values and defaults based on the type - annotations. If validate=True, the config will be validated against the - type annotations of the registered functions referenced in the config - (if available) and/or the schema (if available). - """ - # Valid: {"optimizer": {"@optimizers": "my_cool_optimizer", "rate": 1.0}} - # Invalid: {"@optimizers": "my_cool_optimizer", "rate": 1.0} - if cls.is_promise(config): - err_msg = "The top-level config object can't be a reference to a registered function." 
- raise ConfigValidationError(config=config, errors=[{"msg": err_msg}]) - # If a Config was loaded with interpolate=False, we assume it needs to - # be interpolated first, otherwise we take it at face value - is_interpolated = not isinstance(config, Config) or config.is_interpolated - section_order = config.section_order if isinstance(config, Config) else None - orig_config = config - if not is_interpolated: - config = Config(orig_config).interpolate() - filled, _, resolved = cls._fill( - config, schema, validate=validate, overrides=overrides, resolve=resolve - ) - filled = Config(filled, section_order=section_order) - # Check that overrides didn't include invalid properties not in config - if validate: - cls._validate_overrides(filled, overrides) - # Merge the original config back to preserve variables if we started - # with a config that wasn't interpolated. Here, we prefer variables to - # allow auto-filling a non-interpolated config without destroying - # variable references. - if not is_interpolated: - filled = filled.merge( - Config(orig_config, is_interpolated=False), remove_extra=True - ) - return dict(resolved), filled - - @classmethod - def _fill( - cls, - config: Union[Config, Dict[str, Dict[str, Any]]], - schema: Type[BaseModel] = EmptySchema, - *, - validate: bool = True, - resolve: bool = True, - parent: str = "", - overrides: Dict[str, Dict[str, Any]] = {}, - ) -> Tuple[ - Union[Dict[str, Any], Config], Union[Dict[str, Any], Config], Dict[str, Any] - ]: - """Build three representations of the config: - 1. All promises are preserved (just like config user would provide). - 2. Promises are replaced by their return values. This is the validation - copy and will be parsed by pydantic. It lets us include hacks to - work around problems (e.g. handling of generators). - 3. Final copy with promises replaced by their return values. - """ - filled: Dict[str, Any] = {} - validation: Dict[str, Any] = {} - final: Dict[str, Any] = {} - for key, value in config.items(): - # If the field name is reserved, we use its alias for validation - v_key = RESERVED_FIELDS.get(key, key) - key_parent = f"{parent}.{key}".strip(".") - if key_parent in overrides: - value = overrides[key_parent] - config[key] = value - if cls.is_promise(value): - if key in schema.__fields__ and not resolve: - # If we're not resolving the config, make sure that the field - # expecting the promise is typed Any so it doesn't fail - # validation if it doesn't receive the function return value - field = schema.__fields__[key] - schema.__fields__[key] = copy_model_field(field, Any) - promise_schema = cls.make_promise_schema(value, resolve=resolve) - filled[key], validation[v_key], final[key] = cls._fill( - value, - promise_schema, - validate=validate, - resolve=resolve, - parent=key_parent, - overrides=overrides, - ) - reg_name, func_name = cls.get_constructor(final[key]) - args, kwargs = cls.parse_args(final[key]) - if resolve: - # Call the function and populate the field value. We can't - # just create an instance of the type here, since this - # wouldn't work for generics / more complex custom types - getter = cls.get(reg_name, func_name) - # We don't want to try/except this and raise our own error - # here, because we want the traceback if the function fails. 
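A minimal sketch of how a promise block like the "Valid" example above is handled by registry.resolve (call the registered function) versus registry.fill (keep the reference, fill in defaults); illustrative only, not part of the patch, using the registered Adam.v1 optimizer as a convenient example:

from thinc.config import Config, registry

CONFIG_STR = """
[optimizer]
@optimizers = "Adam.v1"
learn_rate = 0.001
"""

config = Config().from_str(CONFIG_STR)

resolved = registry.resolve(config)
optimizer = resolved["optimizer"]  # the constructed Optimizer instance

filled = registry.fill(config)
# filled["optimizer"] still contains the "@optimizers" reference, with the
# remaining Adam.v1 defaults (beta1, beta2, eps, ...) filled in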
- getter_result = getter(*args, **kwargs) - else: - # We're not resolving and calling the function, so replace - # the getter_result with a Promise class - getter_result = Promise( - registry=reg_name, name=func_name, args=args, kwargs=kwargs - ) - validation[v_key] = getter_result - final[key] = getter_result - if isinstance(validation[v_key], GeneratorType): - # If value is a generator we can't validate type without - # consuming it (which doesn't work if it's infinite – see - # schedule for examples). So we skip it. - validation[v_key] = [] - elif hasattr(value, "items"): - field_type = EmptySchema - if key in schema.__fields__: - field = schema.__fields__[key] - field_type = field.type_ - if not isinstance(field.type_, ModelMetaclass): - # If we don't have a pydantic schema and just a type - field_type = EmptySchema - filled[key], validation[v_key], final[key] = cls._fill( - value, - field_type, - validate=validate, - resolve=resolve, - parent=key_parent, - overrides=overrides, - ) - if key == ARGS_FIELD and isinstance(validation[v_key], dict): - # If the value of variable positional args is a dict (e.g. - # created via config blocks), only use its values - validation[v_key] = list(validation[v_key].values()) - final[key] = list(final[key].values()) - else: - filled[key] = value - # Prevent pydantic from consuming generator if part of a union - validation[v_key] = ( - value if not isinstance(value, GeneratorType) else [] - ) - final[key] = value - # Now that we've filled in all of the promises, update with defaults - # from schema, and validate if validation is enabled - exclude = [] - if validate: - try: - result = schema.parse_obj(validation) - except ValidationError as e: - raise ConfigValidationError( - config=config, errors=e.errors(), parent=parent - ) from None - else: - # Same as parse_obj, but without validation - result = schema.construct(**validation) - # If our schema doesn't allow extra values, we need to filter them - # manually because .construct doesn't parse anything - if schema.Config.extra in (Extra.forbid, Extra.ignore): - fields = schema.__fields__.keys() - exclude = [k for k in result.__fields_set__ if k not in fields] - exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) - validation.update(result.dict(exclude=exclude_validation)) - filled, final = cls._update_from_parsed(validation, filled, final) - if exclude: - filled = {k: v for k, v in filled.items() if k not in exclude} - validation = {k: v for k, v in validation.items() if k not in exclude} - final = {k: v for k, v in final.items() if k not in exclude} - return filled, validation, final - - @classmethod - def _update_from_parsed( - cls, validation: Dict[str, Any], filled: Dict[str, Any], final: Dict[str, Any] - ): - """Update the final result with the parsed config like converted - values recursively. 
- """ - for key, value in validation.items(): - if key in RESERVED_FIELDS.values(): - continue # skip aliases for reserved fields - if key not in filled: - filled[key] = value - if key not in final: - final[key] = value - if isinstance(value, dict): - filled[key], final[key] = cls._update_from_parsed( - value, filled[key], final[key] - ) - # Update final config with parsed value if they're not equal (in - # value and in type) but not if it's a generator because we had to - # replace that to validate it correctly - elif key == ARGS_FIELD: - continue # don't substitute if list of positional args - elif isinstance(value, numpy.ndarray): # check numpy first, just in case - final[key] = value - elif ( - value != final[key] or not isinstance(type(value), type(final[key])) - ) and not isinstance(final[key], GeneratorType): - final[key] = value - return filled, final - - @classmethod - def _validate_overrides(cls, filled: Config, overrides: Dict[str, Any]): - """Validate overrides against a filled config to make sure there are - no references to properties that don't exist and weren't used.""" - error_msg = "Invalid override: config value doesn't exist" - errors = [] - for override_key in overrides.keys(): - if not cls._is_in_config(override_key, filled): - errors.append({"msg": error_msg, "loc": [override_key]}) - if errors: - raise ConfigValidationError(config=filled, errors=errors) - - @classmethod - def _is_in_config(cls, prop: str, config: Union[Dict[str, Any], Config]): - """Check whether a nested config property like "section.subsection.key" - is in a given config.""" - tree = prop.split(".") - obj = dict(config) - while tree: - key = tree.pop(0) - if isinstance(obj, dict) and key in obj: - obj = obj[key] - else: - return False - return True - - @classmethod - def is_promise(cls, obj: Any) -> bool: - """Check whether an object is a "promise", i.e. contains a reference - to a registered function (via a key starting with `"@"`. - """ - if not hasattr(obj, "keys"): - return False - id_keys = [k for k in obj.keys() if k.startswith("@")] - if len(id_keys): - return True - return False - - @classmethod - def get_constructor(cls, obj: Dict[str, Any]) -> Tuple[str, str]: - id_keys = [k for k in obj.keys() if k.startswith("@")] - if len(id_keys) != 1: - err_msg = f"A block can only contain one function registry reference. Got: {id_keys}" - raise ConfigValidationError(config=obj, errors=[{"msg": err_msg}]) - else: - key = id_keys[0] - value = obj[key] - return (key[1:], value) - - @classmethod - def parse_args(cls, obj: Dict[str, Any]) -> Tuple[List[Any], Dict[str, Any]]: - args = [] - kwargs = {} - for key, value in obj.items(): - if not key.startswith("@"): - if key == ARGS_FIELD: - args = value - elif key in RESERVED_FIELDS.values(): - continue - else: - kwargs[key] = value - return args, kwargs - - @classmethod - def make_promise_schema( - cls, obj: Dict[str, Any], *, resolve: bool = True - ) -> Type[BaseModel]: - """Create a schema for a promise dict (referencing a registry function) - by inspecting the function signature. 
- """ - reg_name, func_name = cls.get_constructor(obj) - if not resolve and not cls.has(reg_name, func_name): - return EmptySchema - func = cls.get(reg_name, func_name) - # Read the argument annotations and defaults from the function signature - id_keys = [k for k in obj.keys() if k.startswith("@")] - sig_args: Dict[str, Any] = {id_keys[0]: (str, ...)} - for param in inspect.signature(func).parameters.values(): - # If no annotation is specified assume it's anything - annotation = param.annotation if param.annotation != param.empty else Any - # If no default value is specified assume that it's required - default = param.default if param.default != param.empty else ... - # Handle spread arguments and use their annotation as Sequence[whatever] - if param.kind == param.VAR_POSITIONAL: - spread_annot = Sequence[annotation] # type: ignore - sig_args[ARGS_FIELD_ALIAS] = (spread_annot, default) - else: - name = RESERVED_FIELDS.get(param.name, param.name) - sig_args[name] = (annotation, default) - sig_args["__config__"] = _PromiseSchemaConfig - return create_model("ArgModel", **sig_args) - __all__ = ["Config", "registry", "ConfigValidationError"] diff --git a/thinc/initializers.py b/thinc/initializers.py index 4842f4f08..1333911a3 100644 --- a/thinc/initializers.py +++ b/thinc/initializers.py @@ -75,7 +75,7 @@ def configure_glorot_uniform_init() -> Callable[[Shape], FloatsXd]: def zero_init(ops: Ops, shape: Shape) -> FloatsXd: - return ops.alloc(shape) + return ops.alloc_f(shape) @registry.initializers("zero_init.v1") diff --git a/thinc/layers/__init__.py b/thinc/layers/__init__.py index b37e38a7a..3a81851b4 100644 --- a/thinc/layers/__init__.py +++ b/thinc/layers/__init__.py @@ -1,5 +1,6 @@ # Weights layers from .cauchysimilarity import CauchySimilarity +from .dish import Dish from .dropout import Dropout from .embed import Embed from .expand_window import expand_window @@ -58,7 +59,7 @@ from .list2padded import list2padded from .ragged2list import ragged2list from .padded2list import padded2list -from .remap_ids import remap_ids +from .remap_ids import remap_ids, remap_ids_v2 from .strings2arrays import strings2arrays from .with_array import with_array from .with_array2d import with_array2d @@ -71,6 +72,7 @@ from .with_getitem import with_getitem from .with_debug import with_debug from .with_nvtx_range import with_nvtx_range +from .with_signpost_interval import with_signpost_interval __all__ = [ @@ -128,5 +130,6 @@ "with_flatten", "with_debug", "with_nvtx_range", + "with_signpost_interval", "remap_ids", ] diff --git a/thinc/layers/dish.py b/thinc/layers/dish.py new file mode 100644 index 000000000..b085946b3 --- /dev/null +++ b/thinc/layers/dish.py @@ -0,0 +1,66 @@ +from typing import Tuple, Optional, Callable, cast + +from ..config import registry +from ..model import Model +from .chain import chain +from .layernorm import LayerNorm +from .dropout import Dropout +from ..types import Floats1d, Floats2d +from ..util import partial, get_width +from ..initializers import he_normal_init, zero_init + + +@registry.layers("Dish.v1") +def Dish( + nO: Optional[int] = None, + nI: Optional[int] = None, + *, + init_W: Callable = he_normal_init, + init_b: Callable = zero_init, + dropout: Optional[float] = None, + normalize: bool = False, +) -> Model[Floats2d, Floats2d]: + model: Model[Floats2d, Floats2d] = Model( + "dish", + forward, + init=partial(init, init_W, init_b), + dims={"nO": nO, "nI": nI}, + params={"W": None, "b": None}, + ) + if normalize: + model = chain(model, LayerNorm(nI=nO)) + if dropout is 
not None: + model = chain(model, cast(Model[Floats2d, Floats2d], Dropout(dropout))) + return model + + +def forward( + model: Model[Floats2d, Floats2d], X: Floats2d, is_train: bool +) -> Tuple[Floats2d, Callable]: + W = cast(Floats2d, model.get_param("W")) + b = cast(Floats1d, model.get_param("b")) + Y_preact = model.ops.affine(X, W, b) + Y = model.ops.dish(Y_preact) + + def backprop(dY: Floats2d) -> Floats2d: + dY = model.ops.backprop_dish(dY, X, inplace=False) + model.inc_grad("b", dY.sum(axis=0)) + model.inc_grad("W", model.ops.gemm(dY, X, trans1=True)) + return model.ops.gemm(dY, W) + + return Y, backprop + + +def init( + init_W: Callable, + init_b: Callable, + model: Model[Floats2d, Floats2d], + X: Optional[Floats2d] = None, + Y: Optional[Floats2d] = None, +) -> None: + if X is not None: + model.set_dim("nI", get_width(X)) + if Y is not None: + model.set_dim("nO", get_width(Y)) + model.set_param("W", init_W(model.ops, (model.get_dim("nO"), model.get_dim("nI")))) + model.set_param("b", init_b(model.ops, (model.get_dim("nO"),))) diff --git a/thinc/layers/gelu.py b/thinc/layers/gelu.py index d49ac77a9..cdb0fb6ee 100644 --- a/thinc/layers/gelu.py +++ b/thinc/layers/gelu.py @@ -34,8 +34,9 @@ def Gelu( return model -def forward(model: Model[Floats2d, Floats2d], - X: Floats2d, is_train: bool) -> Tuple[Floats2d, Callable]: +def forward( + model: Model[Floats2d, Floats2d], X: Floats2d, is_train: bool +) -> Tuple[Floats2d, Callable]: W = cast(Floats2d, model.get_param("W")) b = cast(Floats1d, model.get_param("b")) Y_preact = model.ops.affine(X, W, b) diff --git a/thinc/layers/hard_swish.py b/thinc/layers/hard_swish.py index 81b1ad8dd..0478fd270 100644 --- a/thinc/layers/hard_swish.py +++ b/thinc/layers/hard_swish.py @@ -34,8 +34,9 @@ def HardSwish( return model -def forward(model: Model[Floats2d, Floats2d], - X: Floats2d, is_train: bool) -> Tuple[Floats2d, Callable]: +def forward( + model: Model[Floats2d, Floats2d], X: Floats2d, is_train: bool +) -> Tuple[Floats2d, Callable]: W = cast(Floats2d, model.get_param("W")) b = cast(Floats1d, model.get_param("b")) Y_preact = model.ops.affine(X, W, b) diff --git a/thinc/layers/hard_swish_mobilenet.py b/thinc/layers/hard_swish_mobilenet.py index 38004c848..6a5dce388 100644 --- a/thinc/layers/hard_swish_mobilenet.py +++ b/thinc/layers/hard_swish_mobilenet.py @@ -34,17 +34,16 @@ def HardSwishMobilenet( return model -def forward(model: Model[Floats2d, Floats2d], - X: Floats2d, is_train: bool) -> Tuple[Floats2d, Callable]: +def forward( + model: Model[Floats2d, Floats2d], X: Floats2d, is_train: bool +) -> Tuple[Floats2d, Callable]: W = cast(Floats2d, model.get_param("W")) b = cast(Floats1d, model.get_param("b")) Y_preact = model.ops.affine(X, W, b) Y = model.ops.hard_swish_mobilenet(Y_preact) def backprop(dY: Floats2d) -> Floats2d: - dY = model.ops.backprop_hard_swish_mobilenet(dY, - Y_preact, - inplace=False) + dY = model.ops.backprop_hard_swish_mobilenet(dY, Y_preact, inplace=False) model.inc_grad("b", dY.sum(axis=0)) model.inc_grad("W", model.ops.gemm(dY, X, trans1=True)) return model.ops.gemm(dY, W) diff --git a/thinc/layers/hashembed.py b/thinc/layers/hashembed.py index 74b85c7cf..e2bdc2e68 100644 --- a/thinc/layers/hashembed.py +++ b/thinc/layers/hashembed.py @@ -62,13 +62,13 @@ def forward( nV = vectors.shape[0] nO = vectors.shape[1] if len(ids) == 0: - output: Floats2d = model.ops.alloc((0, nO), dtype=vectors.dtype) + output: Floats2d = model.ops.alloc2f(0, nO, dtype=vectors.dtype) else: ids = model.ops.as_contig(ids, dtype="uint64") nN = 
ids.shape[0] seed: int = model.attrs["seed"] keys = model.ops.hash(ids, seed) % nV - output = vectors[keys].sum(axis=1) + output = model.ops.gather_add(vectors, keys) drop_mask = None if is_train: dropout: Optional[float] = model.attrs.get("dropout_rate") diff --git a/thinc/layers/layernorm.py b/thinc/layers/layernorm.py index cf22015ed..684489c54 100644 --- a/thinc/layers/layernorm.py +++ b/thinc/layers/layernorm.py @@ -17,7 +17,7 @@ def LayerNorm(nI: Optional[int] = None) -> Model[InT, InT]: forward, init=init, dims={"nI": nI, "nO": nI}, - params={"G": None, "b": None} + params={"G": None, "b": None}, ) diff --git a/thinc/layers/pytorchwrapper.py b/thinc/layers/pytorchwrapper.py index 8e05856bb..882132dcb 100644 --- a/thinc/layers/pytorchwrapper.py +++ b/thinc/layers/pytorchwrapper.py @@ -1,9 +1,10 @@ from typing import Callable, Tuple, Optional, Any, cast +from ..compat import torch from ..model import Model from ..shims import PyTorchGradScaler, PyTorchShim from ..config import registry -from ..util import is_xp_array, is_torch_array +from ..util import is_xp_array, is_torch_array, partial from ..util import xp2torch, torch2xp, convert_recursive from ..types import Floats3d, ArgsKwargs, Padded @@ -76,6 +77,7 @@ def PyTorchWrapper_v2( convert_outputs: Optional[Callable] = None, mixed_precision: bool = False, grad_scaler: Optional[PyTorchGradScaler] = None, + device: Optional["torch.device"] = None, ) -> Model[Any, Any]: """Wrap a PyTorch model, so that it has the same API as Thinc models. To optimize the model, you'll need to create a PyTorch optimizer and call @@ -105,6 +107,10 @@ def PyTorchWrapper_v2( The gradient scaler to use for mixed-precision training. If this argument is set to "None" and mixed precision is enabled, a gradient scaler with the default configuration is used. + device: + The PyTorch device to run the model on. When this argument is + set to "None", the default device for the currently active Thinc + ops is used. 
""" if convert_inputs is None: convert_inputs = convert_pytorch_default_inputs @@ -116,7 +122,10 @@ def PyTorchWrapper_v2( attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs}, shims=[ PyTorchShim( - pytorch_model, mixed_precision=mixed_precision, grad_scaler=grad_scaler + pytorch_model, + mixed_precision=mixed_precision, + grad_scaler=grad_scaler, + device=device, ) ], dims={"nI": None, "nO": None}, @@ -149,7 +158,8 @@ def backprop(dY: Any) -> Any: def convert_pytorch_default_inputs( model: Model, X: Any, is_train: bool ) -> Tuple[ArgsKwargs, Callable[[ArgsKwargs], Any]]: - xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train) + shim = cast(PyTorchShim, model.shims[0]) + xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train, device=shim.device) converted = convert_recursive(is_xp_array, xp2torch_, X) if isinstance(converted, ArgsKwargs): @@ -181,11 +191,14 @@ def reverse_conversion(dXtorch): def convert_pytorch_default_outputs(model: Model, X_Ytorch: Any, is_train: bool): + shim = cast(PyTorchShim, model.shims[0]) X, Ytorch = X_Ytorch Y = convert_recursive(is_torch_array, torch2xp, Ytorch) def reverse_conversion(dY: Any) -> ArgsKwargs: - dYtorch = convert_recursive(is_xp_array, xp2torch, dY) + dYtorch = convert_recursive( + is_xp_array, partial(xp2torch, device=shim.device), dY + ) return ArgsKwargs(args=((Ytorch,),), kwargs={"grad_tensors": dYtorch}) return Y, reverse_conversion @@ -195,6 +208,7 @@ def reverse_conversion(dY: Any) -> ArgsKwargs: def convert_rnn_inputs(model: Model, Xp: Padded, is_train: bool): + shim = cast(PyTorchShim, model.shims[0]) size_at_t = Xp.size_at_t lengths = Xp.lengths indices = Xp.indices @@ -203,15 +217,19 @@ def convert_from_torch_backward(d_inputs: ArgsKwargs) -> Padded: dX = torch2xp(d_inputs.args[0]) return Padded(dX, size_at_t, lengths, indices) # type: ignore - output = ArgsKwargs(args=(xp2torch(Xp.data, requires_grad=True), None), kwargs={}) + output = ArgsKwargs( + args=(xp2torch(Xp.data, requires_grad=True, device=shim.device), None), + kwargs={}, + ) return output, convert_from_torch_backward def convert_rnn_outputs(model: Model, inputs_outputs: Tuple, is_train): + shim = cast(PyTorchShim, model.shims[0]) Xp, (Ytorch, _) = inputs_outputs def convert_for_torch_backward(dYp: Padded) -> ArgsKwargs: - dYtorch = xp2torch(dYp.data, requires_grad=True) + dYtorch = xp2torch(dYp.data, requires_grad=True, device=shim.device) return ArgsKwargs(args=(Ytorch,), kwargs={"grad_tensors": dYtorch}) Y = cast(Floats3d, torch2xp(Ytorch)) diff --git a/thinc/layers/reduce_first.py b/thinc/layers/reduce_first.py index df7541315..ab72cb5e3 100644 --- a/thinc/layers/reduce_first.py +++ b/thinc/layers/reduce_first.py @@ -1,35 +1,31 @@ -from typing import Callable, Tuple, cast, TypeVar +from typing import Callable, Tuple, cast from ..model import Model from ..config import registry -from ..types import Ragged, ArrayXd +from ..types import Ragged, Floats2d from ..util import ArrayInfo -OutT = TypeVar("OutT", bound=ArrayXd) + +InT = Ragged +OutT = Floats2d @registry.layers("reduce_first.v1") -def reduce_first() -> Model[Ragged, OutT]: +def reduce_first() -> Model[InT, OutT]: """Reduce ragged-formatted sequences to their first element.""" return Model("reduce_first", forward) def forward( - model: Model[Ragged, OutT], Xr: Ragged, is_train: bool -) -> Tuple[OutT, Callable[[OutT], Ragged]]: - starts = model.ops.alloc1i(Xr.lengths.shape[0]) - starts[1:] += Xr.lengths.cumsum()[:-1] - X = Xr.dataXd - Y = cast(OutT, X[starts]) - x_shape = 
Xr.dataXd.shape - lengths = Xr.lengths + model: Model[InT, OutT], Xr: InT, is_train: bool +) -> Tuple[OutT, Callable[[OutT], InT]]: + Y, starts_ends = model.ops.reduce_first(cast(Floats2d, Xr.data), Xr.lengths) array_info = ArrayInfo.from_array(Y) - def backprop(dY: OutT) -> Ragged: + def backprop(dY: OutT) -> InT: array_info.check_consistency(dY) - dX: OutT = model.ops.alloc(x_shape, dtype=dY.dtype) - dX[starts] = dY # type: ignore[assignment] - return Ragged(dX, lengths) + dX = model.ops.backprop_reduce_first(dY, starts_ends) + return Ragged(dX, Xr.lengths) return Y, backprop diff --git a/thinc/layers/reduce_last.py b/thinc/layers/reduce_last.py index e45a65d12..b8194ec2b 100644 --- a/thinc/layers/reduce_last.py +++ b/thinc/layers/reduce_last.py @@ -1,32 +1,29 @@ -from typing import Callable, Tuple, cast, TypeVar +from typing import Callable, Tuple, cast from ..model import Model from ..config import registry -from ..types import Ragged, ArrayXd +from ..types import Ragged, Floats2d from ..util import ArrayInfo -OutT = TypeVar("OutT", bound=ArrayXd) +InT = Ragged +OutT = Floats2d @registry.layers("reduce_last.v1") -def reduce_last() -> Model[Ragged, OutT]: +def reduce_last() -> Model[InT, OutT]: """Reduce ragged-formatted sequences to their last element.""" return Model("reduce_last", forward) def forward( - model: Model[Ragged, OutT], Xr: Ragged, is_train: bool -) -> Tuple[OutT, Callable[[OutT], Ragged]]: - ends = Xr.lengths.cumsum() - 1 - Y = cast(OutT, Xr.dataXd[ends]) - x_shape = Xr.dataXd.shape - lengths = Xr.lengths + model: Model[InT, OutT], Xr: InT, is_train: bool +) -> Tuple[OutT, Callable[[OutT], InT]]: + Y, lasts = model.ops.reduce_last(cast(Floats2d, Xr.data), Xr.lengths) array_info = ArrayInfo.from_array(Y) - def backprop(dY: OutT) -> Ragged: + def backprop(dY: OutT) -> InT: array_info.check_consistency(dY) - dX: OutT = model.ops.alloc(x_shape, dtype=dY.dtype) - dX[ends] = dY # type: ignore[assignment] - return Ragged(dX, lengths) + dX = model.ops.backprop_reduce_last(dY, lasts) + return Ragged(dX, Xr.lengths) return Y, backprop diff --git a/thinc/layers/remap_ids.py b/thinc/layers/remap_ids.py index f0f010acb..265b24a9d 100644 --- a/thinc/layers/remap_ids.py +++ b/thinc/layers/remap_ids.py @@ -1,18 +1,23 @@ -from typing import Tuple, Callable, Sequence, Dict, Any +from typing import Tuple, Callable, Sequence, cast +from typing import Dict, Union, Optional, Hashable, Any from ..model import Model from ..config import registry -from ..types import Ints2d, DTypes +from ..types import Ints1d, Ints2d, DTypes +from ..util import is_xp_array, to_numpy -InT = Sequence[Any] +InT = Union[Sequence[Hashable], Ints1d, Ints2d] OutT = Ints2d +InT_v1 = Sequence[Any] +OutT_v1 = Ints2d + @registry.layers("remap_ids.v1") def remap_ids( mapping_table: Dict[Any, int] = {}, default: int = 0, dtype: DTypes = "i" -) -> Model[InT, OutT]: +) -> Model[InT_v1, OutT_v1]: """Remap string or integer inputs using a mapping table, usually as a preprocess before embeddings. The mapping table can be passed in on input, or updated after the layer has been created. 
The mapping table is stored in @@ -26,7 +31,7 @@ def remap_ids( def forward( - model: Model[InT, OutT], inputs: InT, is_train: bool + model: Model[InT_v1, OutT_v1], inputs: InT_v1, is_train: bool ) -> Tuple[OutT, Callable]: table = model.attrs["mapping_table"] default = model.attrs["default"] @@ -35,7 +40,60 @@ def forward( arr = model.ops.asarray2i(values, dtype=dtype) output = model.ops.reshape2i(arr, -1, 1) - def backprop(dY: OutT) -> InT: + def backprop(dY: OutT_v1) -> InT: return [] return output, backprop + + +@registry.layers("remap_ids.v2") +def remap_ids_v2( + mapping_table: Optional[Union[Dict[int, int], Dict[str, int]]] = None, + default: int = 0, + *, + column: Optional[int] = None +) -> Model[InT, OutT]: + """Remap string or integer inputs using a mapping table, + usually as a preprocessing step before embeddings. + The mapping table can be passed in on input, + or updated after the layer has been created. + The mapping table is stored in the "mapping_table" attribute. + Two dimensional arrays can be provided as input in which case + the 'column' chooses which column to process. This is useful + to work together with FeatureExtractor in spaCy. + """ + return Model( + "remap_ids", + forward_v2, + attrs={"mapping_table": mapping_table, "default": default, "column": column}, + ) + + +def forward_v2( + model: Model[InT, OutT], inputs: InT, is_train: bool +) -> Tuple[OutT, Callable]: + table = model.attrs["mapping_table"] + if table is None: + raise ValueError("'mapping table' not set") + default = model.attrs["default"] + column = model.attrs["column"] + if is_xp_array(inputs): + xp_input = True + if column is not None: + idx = to_numpy(cast(Ints2d, inputs)[:, column]) + else: + idx = to_numpy(inputs) + else: + xp_input = False + idx = inputs + values = [table.get(x, default) for x in idx] + arr = model.ops.asarray2i(values, dtype="i") + output = model.ops.reshape2i(arr, -1, 1) + + def backprop(dY: OutT) -> InT: + if xp_input: + return model.ops.xp.empty(dY.shape) # type: ignore + else: + return [] + + return output, backprop diff --git a/thinc/layers/sigmoid_activation.py b/thinc/layers/sigmoid_activation.py index 8b3982aea..b87261075 100644 --- a/thinc/layers/sigmoid_activation.py +++ b/thinc/layers/sigmoid_activation.py @@ -2,23 +2,23 @@ from ..model import Model from ..config import registry -from ..types import FloatsXd - - -InT = TypeVar("InT", bound=FloatsXd) +from ..types import FloatsXdT @registry.layers("sigmoid_activation.v1") -def sigmoid_activation() -> Model[InT, InT]: +def sigmoid_activation() -> Model[FloatsXdT, FloatsXdT]: return Model("sigmoid_activation", forward) -def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]: +def forward( + model: Model[FloatsXdT, FloatsXdT], X: FloatsXdT, is_train: bool +) -> Tuple[FloatsXdT, Callable]: Y = model.ops.sigmoid(X, inplace=False) - def backprop(dY: InT) -> InT: + def backprop(dY: FloatsXdT) -> FloatsXdT: return cast( - InT, dY * model.ops.dsigmoid(Y, inplace=False) # type:ignore[operator] + FloatsXdT, + dY * model.ops.dsigmoid(Y, inplace=False), # type:ignore[operator] ) return Y, backprop diff --git a/thinc/layers/swish.py b/thinc/layers/swish.py index ea5444b49..a05a0dc72 100644 --- a/thinc/layers/swish.py +++ b/thinc/layers/swish.py @@ -34,8 +34,9 @@ def Swish( return model -def forward(model: Model[Floats2d, Floats2d], - X: Floats2d, is_train: bool) -> Tuple[Floats2d, Callable]: +def forward( + model: Model[Floats2d, Floats2d], X: Floats2d, is_train: bool +) -> Tuple[Floats2d, Callable]: 
W = cast(Floats2d, model.get_param("W")) b = cast(Floats1d, model.get_param("b")) Y_preact = model.ops.affine(X, W, b) diff --git a/thinc/layers/tensorflowwrapper.py b/thinc/layers/tensorflowwrapper.py index dc1a48752..7e166ea50 100644 --- a/thinc/layers/tensorflowwrapper.py +++ b/thinc/layers/tensorflowwrapper.py @@ -7,11 +7,7 @@ from ..util import xp2tensorflow, tensorflow2xp, assert_tensorflow_installed from ..util import is_tensorflow_array, convert_recursive, is_xp_array from ..types import ArrayXd, ArgsKwargs - -try: - import tensorflow as tf -except ImportError: # pragma: no cover - pass +from ..compat import tensorflow as tf InT = TypeVar("InT") OutT = TypeVar("OutT") diff --git a/thinc/layers/with_signpost_interval.py b/thinc/layers/with_signpost_interval.py new file mode 100644 index 000000000..9a468d896 --- /dev/null +++ b/thinc/layers/with_signpost_interval.py @@ -0,0 +1,50 @@ +from typing import Optional, Callable, Any, Tuple, TypeVar + +from ..compat import has_os_signpost, os_signpost +from ..model import Model + + +_ModelT = TypeVar("_ModelT", bound=Model) + + +def with_signpost_interval( + layer: _ModelT, + signposter: "os_signpost.Signposter", + name: Optional[str] = None, +) -> _ModelT: + """Wraps any layer and marks the init, forward and backprop phases using + signpost intervals for macOS Instruments profiling + + By default, the name of the layer is used as the name of the range, + followed by the name of the pass. + """ + if not has_os_signpost: + raise ValueError( + "with_signpost_interval layer requires the 'os_signpost' package" + ) + + name = layer.name if name is None else name + + orig_forward = layer._func + orig_init = layer.init + + def forward(model: Model, X: Any, is_train: bool) -> Tuple[Any, Callable]: + with signposter.use_interval(f"{name} forward"): + layer_Y, layer_callback = orig_forward(model, X, is_train=is_train) + + def backprop(dY: Any) -> Any: + with signposter.use_interval(f"{name} backprop"): + return layer_callback(dY) + + return layer_Y, backprop + + def init(_model: Model, X: Any, Y: Any) -> Model: + if orig_init is not None: + with signposter.use_interval(f"{name} init"): + return orig_init(layer, X, Y) + else: + return layer + + layer.replace_callbacks(forward, init=init) + + return layer diff --git a/thinc/model.py b/thinc/model.py index 7b8560a8a..08366523e 100644 --- a/thinc/model.py +++ b/thinc/model.py @@ -462,10 +462,35 @@ def copy(self: SelfT) -> SelfT: layers will also be deep-copied. The copy will receive a distinct `model.id` value. 
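# A sketch of wrapping a layer with with_signpost_interval (added above) so
# its init/forward/backprop show up as intervals in macOS Instruments. The
# Signposter constructor arguments are an assumption about the os-signpost
# package's API, and the layer is imported from its module path here in case
# it is not re-exported from thinc.api. macOS only.
import numpy
from os_signpost import Signposter
from thinc.api import Linear
from thinc.layers.with_signpost_interval import with_signpost_interval

signposter = Signposter("com.example.thinc", Signposter.Category.DynamicTracing)
model = with_signpost_interval(Linear(nO=8, nI=8), signposter, name="linear")
model.initialize()                                   # "linear init" interval
Y = model.predict(numpy.zeros((2, 8), dtype="f"))    # "linear forward" interval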
""" + return self._copy() + + def _copy( + self: SelfT, seen: Optional[Dict[int, Union["Model", Shim]]] = None + ) -> SelfT: + if seen is None: + seen = {} params = {} for name in self.param_names: params[name] = self.get_param(name) if self.has_param(name) else None + copied_layers: List[Model] = [] + for layer in self.layers: + if id(layer) in seen: + copied_layers.append(cast(Model, seen[id(layer)])) + else: + copied_layer = layer._copy(seen) + seen[id(layer)] = copied_layer + copied_layers.append(copied_layer) + + copied_shims = [] + for shim in self.shims: + if id(shim) in seen: + copied_shims.append(cast(Shim, seen[id(shim)])) + else: + copied_shim = shim.copy() + seen[id(shim)] = copied_shim + copied_shims.append(copied_shim) + copied: Model[InT, OutT] = Model( self.name, self._func, @@ -473,8 +498,8 @@ def copy(self: SelfT) -> SelfT: params=copy.deepcopy(params), dims=copy.deepcopy(self._dims), attrs=copy.deepcopy(self._attrs), - layers=[layer.copy() for layer in self.layers], - shims=[shim.copy() for shim in self.shims], + layers=copied_layers, + shims=copied_shims, ) for name in self.grad_names: copied.set_grad(name, self.get_grad(name).copy()) diff --git a/thinc/optimizers.py b/thinc/optimizers.py index c8e38e84b..f34cd2ff8 100644 --- a/thinc/optimizers.py +++ b/thinc/optimizers.py @@ -279,7 +279,7 @@ def _radam(self, ops, weights, grad, lr_scale, key, nr_upd): # exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) exp_avg_sq *= beta2 - exp_avg_sq += (1 - beta2) * (gradient_1D ** 2) + exp_avg_sq += (1 - beta2) * (gradient_1D**2) # exp_avg.mul_(beta1).add_(1 - beta1, grad) exp_avg *= beta1 exp_avg += (1 - beta1) * gradient_1D @@ -338,9 +338,9 @@ def _adam(self, ops, weights, gradient, lr_scale, key, nr_upd): mom2 = self.mom2[key] b1 = self.b1 b2 = self.b2 - fix1 = 1.0 - (b1 ** nr_upd) - fix2 = 1.0 - (b2 ** nr_upd) - lr = self.learn_rate * fix2 ** 0.5 / fix1 + fix1 = 1.0 - (b1**nr_upd) + fix2 = 1.0 - (b2**nr_upd) + lr = self.learn_rate * fix2**0.5 / fix1 eps = self.eps # needs to be 1D going into the adam function weights_1D, gradient_1D, mom1, mom2 = ops.adam( diff --git a/thinc/shims/mxnet.py b/thinc/shims/mxnet.py index 0357be685..3962a2ef5 100644 --- a/thinc/shims/mxnet.py +++ b/thinc/shims/mxnet.py @@ -2,18 +2,12 @@ import srsly import copy -try: - import mxnet.autograd - import mxnet.optimizer - import mxnet as mx -except ImportError: # pragma: no cover - pass - from ..util import mxnet2xp, convert_recursive, make_tempfile, xp2mxnet from ..util import get_array_module from ..optimizers import Optimizer from ..types import ArgsKwargs, FloatsXd from .shim import Shim +from ..compat import mxnet as mx class MXNetShim(Shim): @@ -33,7 +27,7 @@ def predict(self, inputs: ArgsKwargs) -> Any: evaluation mode. 
""" mx.autograd.set_training(train_mode=False) - with mxnet.autograd.pause(): + with mx.autograd.pause(): outputs = self._model(*inputs.args, **inputs.kwargs) mx.autograd.set_training(train_mode=True) return outputs @@ -50,7 +44,7 @@ def begin_update(self, inputs: ArgsKwargs): def backprop(grads): mx.autograd.set_recording(False) - mxnet.autograd.backward(*grads.args, **grads.kwargs) + mx.autograd.backward(*grads.args, **grads.kwargs) return convert_recursive( lambda x: hasattr(x, "grad"), lambda x: x.grad, inputs ) diff --git a/thinc/shims/pytorch.py b/thinc/shims/pytorch.py index 03a08da83..81a2fe11f 100644 --- a/thinc/shims/pytorch.py +++ b/thinc/shims/pytorch.py @@ -4,16 +4,9 @@ import itertools import srsly -try: - import torch.autograd - from torch.cuda import amp - import torch.optim - import torch -except ImportError: # pragma: no cover - pass - from ..util import torch2xp, xp2torch, convert_recursive, iterate_recursive -from ..util import has_torch_amp +from ..util import get_torch_default_device +from ..compat import torch from ..backends import get_current_ops, context_pools, CupyOps from ..backends import set_gpu_allocator from ..optimizers import Optimizer @@ -33,6 +26,10 @@ class PyTorchShim(Shim): The gradient scaler to use for mixed-precision training. If this argument is set to "None" and mixed precision is enabled, a gradient scaler with the default configuration is used. + device: + The PyTorch device to run the model on. When this argument is + set to "None", the default device for the currently active Thinc + ops is used. """ def __init__( @@ -42,12 +39,20 @@ def __init__( optimizer: Any = None, mixed_precision: bool = False, grad_scaler: Optional[PyTorchGradScaler] = None, + device: Optional["torch.device"] = None, ): super().__init__(model, config, optimizer) + if device is None: + device = get_torch_default_device() + if model is not None: + model.to(device) + if grad_scaler is None: grad_scaler = PyTorchGradScaler(mixed_precision) + grad_scaler.to_(device) + self._grad_scaler = grad_scaler self._mixed_precision = mixed_precision @@ -66,6 +71,14 @@ def __call__(self, inputs, is_train): else: return self.predict(inputs), lambda a: ... + @property + def device(self): + p = next(self._model.parameters(), None) + if p is None: + return get_torch_default_device() + else: + return p.device + def predict(self, inputs: ArgsKwargs) -> Any: """Pass inputs through to the underlying PyTorch model, and return the output. No conversions are performed. The PyTorch model is set into @@ -73,7 +86,7 @@ def predict(self, inputs: ArgsKwargs) -> Any: """ self._model.eval() with torch.no_grad(): - with amp.autocast(self._mixed_precision): + with torch.cuda.amp.autocast(self._mixed_precision): outputs = self._model(*inputs.args, **inputs.kwargs) self._model.train() return outputs @@ -87,7 +100,7 @@ def begin_update(self, inputs: ArgsKwargs): self._model.train() # Note: mixed-precision autocast must not be applied to backprop. 
- with amp.autocast(self._mixed_precision): + with torch.cuda.amp.autocast(self._mixed_precision): output = self._model(*inputs.args, **inputs.kwargs) def backprop(grads): @@ -134,7 +147,9 @@ def finish_update(self, optimizer: Optimizer): cast(FloatsXd, torch2xp(torch_data.data)), cast(FloatsXd, torch2xp(torch_data.grad)), ) - torch_data.data = xp2torch(param, requires_grad=True) + torch_data.data = xp2torch( + param, requires_grad=True, device=torch_data.device + ) torch_data.grad.zero_() self._grad_scaler.update() @@ -145,7 +160,7 @@ def use_params(self, params): state_dict = {} for k, v in params.items(): if hasattr(k, "startswith") and k.startswith(key_prefix): - state_dict[k.replace(key_prefix, "")] = xp2torch(v) + state_dict[k.replace(key_prefix, "")] = xp2torch(v, device=self.device) if state_dict: backup = {k: v.clone() for k, v in self._model.state_dict().items()} self._model.load_state_dict(state_dict) @@ -172,17 +187,12 @@ def to_bytes(self): return srsly.msgpack_dumps(msg) def from_bytes(self, bytes_data): - ops = get_current_ops() + device = get_torch_default_device() msg = srsly.msgpack_loads(bytes_data) self.cfg = msg["config"] filelike = BytesIO(msg["state"]) filelike.seek(0) - if ops.device_type == "cpu": - map_location = "cpu" - else: # pragma: no cover - device_id = torch.cuda.current_device() - map_location = "cuda:%d" % device_id - self._model.load_state_dict(torch.load(filelike, map_location=map_location)) - self._model.to(map_location) - self._grad_scaler.to_(map_location) + self._model.load_state_dict(torch.load(filelike, map_location=device)) + self._model.to(device) + self._grad_scaler.to_(device) return self diff --git a/thinc/shims/pytorch_grad_scaler.py b/thinc/shims/pytorch_grad_scaler.py index 9fc77209e..8db11bcae 100644 --- a/thinc/shims/pytorch_grad_scaler.py +++ b/thinc/shims/pytorch_grad_scaler.py @@ -1,11 +1,7 @@ from typing import Dict, Iterable, List, Union, cast -from ..util import has_torch_amp, is_torch_array - -try: - import torch -except ImportError: # pragma: no cover - pass +from ..compat import has_torch_amp, torch +from ..util import is_torch_array class PyTorchGradScaler: @@ -55,12 +51,11 @@ def __init__( self._backoff_factor = backoff_factor self._growth_interval = growth_interval - self._found_inf = torch.full((1,), 0.0) self._growth_tracker = torch.full((1,), 0, dtype=torch.int) self._scale = torch.full((1,), init_scale) + self._found_inf = False def to_(self, device): - self._found_inf = self._found_inf.to(device) self._growth_tracker = self._growth_tracker.to(device) self._scale = self._scale.to(device) @@ -136,7 +131,7 @@ def _tensors_per_device(self, tensors): @property def found_inf(self): - return bool(self._found_inf) != 0 + return self._found_inf def unscale(self, tensors): """Unscale the given tensors. 
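# Sketch of the intended flow with the reworked overflow tracking above:
# `_found_inf` is now a plain Python bool. unscale() flips it when any
# gradient overflowed, callers can skip the optimizer step while found_inf is
# True, and update() only materialises the flag as a device tensor for the
# torch._amp_update_scale_ call before clearing it again. Method names are
# taken from the diff; the commented calls are illustrative only.
from thinc.api import PyTorchGradScaler

scaler = PyTorchGradScaler(enabled=True, init_scale=2.0**16)
# scaled_grads = scaler.scale(grads)       # scale before accumulating
# overflow = scaler.unscale(scaled_grads)  # unscale in place, returns found_inf
# if not scaler.found_inf:
#     ...take the optimizer step...
# scaler.update()                          # grow or back off the scale factor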
Returns True if any of the gradients were infinite.""" @@ -156,9 +151,10 @@ def unscale(self, tensors): device_tensors, found_inf_device, inv_scale_device ) - self._found_inf += found_inf_device.to(self._found_inf.device) + if bool(found_inf_device != 0): + self._found_inf = True - return bool(self._found_inf != 0) + return self._found_inf def update(self): """ @@ -169,14 +165,17 @@ def update(self): if not self._enabled: return + found_inf_device = torch.full( + (1,), 1.0 if self._found_inf else 0.0, device=self._scale.device + ) torch._amp_update_scale_( self._scale, self._growth_tracker, - self._found_inf, + found_inf_device, self._growth_factor, self._backoff_factor, self._growth_interval, ) # Clear infinity found status - self._found_inf = torch.zeros_like(self._found_inf) + self._found_inf = False diff --git a/thinc/shims/shim.py b/thinc/shims/shim.py index 840589ff2..0c246e8d4 100644 --- a/thinc/shims/shim.py +++ b/thinc/shims/shim.py @@ -26,7 +26,8 @@ class Shim: # pragma: no cover def __init__(self, model: Any, config=None, optimizer: Any = None): with Shim.global_id_lock: Shim.global_id += 1 - self.id = Shim.global_id + self.id = Shim.global_id + self.cfg = dict(config) if config is not None else {} self._model = model self._optimizer = optimizer diff --git a/thinc/shims/tensorflow.py b/thinc/shims/tensorflow.py index f226c48df..d630d86f9 100644 --- a/thinc/shims/tensorflow.py +++ b/thinc/shims/tensorflow.py @@ -10,21 +10,8 @@ from ..types import ArgsKwargs, ArrayXd from ..util import get_array_module from .shim import Shim - -try: - import cupy -except ImportError: - cupy = None - -try: - import tensorflow as tf -except ImportError: # pragma: no cover - pass - -try: - import h5py -except ImportError: # pragma: no cover - pass +from ..compat import tensorflow as tf +from ..compat import cupy, h5py keras_model_fns = catalogue.create("thinc", "keras", entry_points=True) diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py index 7ab5496be..04ab7d231 100644 --- a/thinc/tests/backends/test_ops.py +++ b/thinc/tests/backends/test_ops.py @@ -8,14 +8,15 @@ from packaging.version import Version from thinc.api import NumpyOps, CupyOps, Ops, get_ops from thinc.api import get_current_ops, use_ops -from thinc.util import has_torch, torch2xp, xp2torch, torch_version, gpu_is_available +from thinc.util import torch2xp, xp2torch +from thinc.compat import has_cupy_gpu, has_torch, torch_version from thinc.api import fix_random_seed from thinc.api import LSTM from thinc.types import Floats2d import inspect from .. 
import strategies -from ..strategies import ndarrays_of_shape +from ..strategies import arrays_BI, ndarrays_of_shape MAX_EXAMPLES = 10 @@ -25,7 +26,7 @@ BLIS_OPS = NumpyOps(use_blis=True) CPU_OPS = [NUMPY_OPS, VANILLA_OPS] XP_OPS = [NUMPY_OPS] -if CupyOps.xp is not None and gpu_is_available(): +if has_cupy_gpu: XP_OPS.append(CupyOps()) ALL_OPS = XP_OPS + [VANILLA_OPS] @@ -61,7 +62,10 @@ def torch_hard_swish_mobilenet(x): return torch.nn.functional.hardswish(x) def torch_sigmoid(x): - return torch.nn.functional.sigmoid(x) + return torch.sigmoid(x) + + def torch_dish(x): + return 0.5 * x * (x / (1 + x * x).sqrt() + 1) # https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py#L37 def torch_gelu_approx(x): @@ -88,6 +92,7 @@ def torch_gelu(x): ("swish", torch_swish), ("hard_swish", torch_hard_swish), ("hard_swish_mobilenet", torch_hard_swish_mobilenet), + ("dish", torch_dish), ("gelu_approx", torch_gelu_approx), ("gelu", torch_gelu), ("sigmoid", torch_sigmoid), @@ -126,6 +131,22 @@ def test_ops_consistency(op): assert str(p1) == str(p2), attr +@pytest.mark.parametrize("ops", ALL_OPS) +def test_adam_incorrect_inputs(ops): + one = ops.xp.zeros(1, dtype="f") + two = ops.xp.zeros(2, dtype="f") + + ops.adam(one, one, one, one, 0.0, 0.0, 0.0, 0.0) + with pytest.raises(ValueError): + ops.adam(two, one, one, one, 0.0, 0.0, 0.0, 0.0) + with pytest.raises(ValueError): + ops.adam(one, two, one, one, 0.0, 0.0, 0.0, 0.0) + with pytest.raises(ValueError): + ops.adam(one, one, two, one, 0.0, 0.0, 0.0, 0.0) + with pytest.raises(ValueError): + ops.adam(one, one, one, two, 0.0, 0.0, 0.0, 0.0) + + @pytest.mark.parametrize("ops", ALL_OPS) def test_alloc(ops): float_methods = (ops.alloc1f, ops.alloc2f, ops.alloc3f, ops.alloc4f) @@ -181,6 +202,38 @@ def test_get_dropout_not_empty(ops): assert mask.shape == shape +@pytest.mark.parametrize("ops", ALL_OPS) +@pytest.mark.parametrize("dtype", FLOAT_TYPES) +@pytest.mark.parametrize("index_dtype", ["int32", "uint32"]) +def test_gather_add(ops, dtype, index_dtype): + table = ops.xp.arange(12, dtype=dtype).reshape(4, 3) + indices = ops.xp.array([[0, 2], [3, 1], [0, 1]], dtype=index_dtype) + gathered = ops.gather_add(table, indices) + ops.xp.testing.assert_allclose( + gathered, [[6.0, 8.0, 10.0], [12.0, 14.0, 16.0], [3.0, 5.0, 7.0]] + ) + + +@pytest.mark.parametrize("ops", XP_OPS) +@given(table=strategies.arrays_BI()) +def test_gather_add_against_numpy(ops, table): + table = ops.asarray(table) + indices = ops.xp.arange(100, dtype="i").reshape(25, 4) % table.shape[0] + ops.xp.testing.assert_allclose( + ops.gather_add(table, indices), + table[indices].sum(1), + atol=1e-5, + ) + + +@pytest.mark.parametrize("ops", ALL_OPS) +def test_gather_add_oob_raises(ops): + table = ops.xp.arange(12, dtype="f").reshape(4, 3) + indices = ops.xp.array([[0, 2], [3, 1], [5, 1]], dtype="i") + with pytest.raises(IndexError): + ops.gather_add(table, indices) + + @pytest.mark.parametrize("ops", CPU_OPS) def test_seq2col_window_one_small(ops): seq = ops.asarray([[1.0], [3.0], [4.0], [5]], dtype="float32") @@ -199,7 +252,7 @@ def test_seq2col_window_one_small(ops): @given(X=strategies.arrays_BOP()) def test_maxout(ops, dtype, X): X = ops.asarray(X, dtype=dtype) - expected_best = X.max(axis=-1) + expected_best = X.max(axis=-1).astype(dtype) predicted_best, which = ops.maxout(X) assert predicted_best.dtype == dtype ops.xp.testing.assert_allclose( @@ -212,6 +265,7 @@ def test_maxout(ops, dtype, X): ops.xp.testing.assert_allclose( ops.xp.take_along_axis(X, 
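# What ops.gather_add computes, per the tests above: for every row of
# `indices`, gather the corresponding rows of `table` and sum them, i.e. the
# NumPy equivalent is table[indices].sum(1). A small sketch with NumpyOps:
from thinc.api import NumpyOps

ops = NumpyOps()
table = ops.xp.arange(12, dtype="f").reshape(4, 3)
indices = ops.xp.asarray([[0, 2], [3, 1]], dtype="i")
gathered = ops.gather_add(table, indices)            # rows 0+2 and rows 3+1
assert ops.xp.allclose(gathered, table[indices].sum(1))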
ops.xp.expand_dims(which, -1), axis=-1), ops.xp.expand_dims(expected_best, -1), + atol=1e-10, ) @@ -574,9 +628,7 @@ def test_backprop_seq2col_window_two(ops, dtype): ops.xp.testing.assert_allclose(seq, expected, atol=0.001, rtol=0.001) -@pytest.mark.skipif( - CupyOps.xp is None or not gpu_is_available(), reason="needs GPU/CuPy" -) +@pytest.mark.skipif(not has_cupy_gpu, reason="needs GPU/CuPy") @pytest.mark.parametrize("nW", [1, 2]) def test_large_seq2col_gpu_against_cpu(nW): cupy_ops = CupyOps() @@ -598,9 +650,7 @@ def test_large_seq2col_gpu_against_cpu(nW): assert_allclose(cols, cols_gpu.get()) -@pytest.mark.skipif( - CupyOps.xp is None or not gpu_is_available(), reason="needs GPU/CuPy" -) +@pytest.mark.skipif(not has_cupy_gpu, reason="needs GPU/CuPy") @pytest.mark.parametrize("nW", [1, 2]) def test_large_backprop_seq2col_gpu_against_cpu(nW): cupy_ops = CupyOps() @@ -771,6 +821,68 @@ def test_backprop_fails_with_incorrect_length(ops, dtype): ) +@pytest.mark.parametrize("ops", ALL_OPS) +@pytest.mark.parametrize("dtype", FLOAT_TYPES) +def test_reduce_first(ops, dtype): + X = ops.asarray2f( + [[1.0, 6.0], [2.0, 7.0], [3.0, 8.0], [4.0, 9.0], [5.0, 10.0]], dtype=dtype + ) + lengths = ops.asarray1i([3, 2]) + Y, starts_ends = ops.reduce_first(X, lengths) + ops.xp.testing.assert_array_equal(starts_ends, ops.asarray1i([0, 3, 5])) + ops.xp.testing.assert_allclose(Y, [[1.0, 6.0], [4.0, 9.0]]) + + lengths = ops.asarray1i([3, 0, 2]) + with pytest.raises(ValueError, match=r"all sequence lengths must be >= 0"): + ops.reduce_last(X, lengths) + + lengths = ops.asarray1i([3, 2, 1]) + with pytest.raises(IndexError, match=r"lengths must sum up to the number of rows"): + ops.reduce_last(X, lengths) + + +@pytest.mark.parametrize("ops", ALL_OPS) +@pytest.mark.parametrize("dtype", FLOAT_TYPES) +def test_backprop_reduce_first(ops, dtype): + dY = ops.asarray2f([[1.0, 3.0], [2.0, 4.0]], dtype=dtype) + starts_ends = ops.asarray1i([0, 3, 5]) + dX = ops.backprop_reduce_first(dY, starts_ends) + ops.xp.testing.assert_allclose( + dX, [[1.0, 3.0], [0.0, 0.0], [0.0, 0.0], [2.0, 4.0], [0.0, 0.0]] + ) + + +@pytest.mark.parametrize("ops", ALL_OPS) +@pytest.mark.parametrize("dtype", FLOAT_TYPES) +def test_reduce_last(ops, dtype): + X = ops.asarray2f( + [[1.0, 6.0], [2.0, 7.0], [3.0, 8.0], [4.0, 9.0], [5.0, 10.0]], dtype=dtype + ) + lengths = ops.asarray1i([3, 2]) + Y, lasts = ops.reduce_last(X, lengths) + ops.xp.testing.assert_array_equal(lasts, ops.asarray1i([2, 4])) + ops.xp.testing.assert_allclose(Y, [[3.0, 8.0], [5.0, 10.0]]) + + lengths = ops.asarray1i([3, 0, 2]) + with pytest.raises(ValueError, match=r"all sequence lengths must be >= 0"): + ops.reduce_last(X, lengths) + + lengths = ops.asarray1i([3, 2, 1]) + with pytest.raises(IndexError, match=r"lengths must sum up to the number of rows"): + ops.reduce_last(X, lengths) + + +@pytest.mark.parametrize("ops", ALL_OPS) +@pytest.mark.parametrize("dtype", FLOAT_TYPES) +def test_backprop_reduce_last(ops, dtype): + dY = ops.asarray2f([[1.0, 3.0], [2.0, 4.0]], dtype=dtype) + lasts = ops.asarray1i([2, 4]) + dX = ops.backprop_reduce_last(dY, lasts) + ops.xp.testing.assert_allclose( + dX, [[0.0, 0.0], [0.0, 0.0], [1.0, 3.0], [0.0, 0.0], [2.0, 4.0]] + ) + + @pytest.mark.parametrize("ops", ALL_OPS) @pytest.mark.parametrize("dtype", FLOAT_TYPES) def test_reduce_max_sm(ops, dtype): @@ -935,6 +1047,7 @@ def test_mish(ops, X): "op", [ "backprop_clipped_linear", + "backprop_dish", "backprop_gelu", "backprop_gelu_approx", "backprop_hard_sigmoid", @@ -1052,6 +1165,16 @@ def 
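# What reduce_first/reduce_last compute, mirroring the tests above: X holds
# the concatenated rows of several sequences and `lengths` the length of each
# sequence; reduce_first returns the first row per sequence (plus the
# start/end offsets its backprop needs), reduce_last the last row (plus the
# indices of those rows). A small sketch with NumpyOps:
from thinc.api import NumpyOps

ops = NumpyOps()
X = ops.asarray2f([[1.0, 6.0], [2.0, 7.0], [3.0, 8.0], [4.0, 9.0], [5.0, 10.0]])
lengths = ops.asarray1i([3, 2])
firsts, starts_ends = ops.reduce_first(X, lengths)   # [[1, 6], [4, 9]]
lasts, last_idx = ops.reduce_last(X, lengths)        # [[3, 8], [5, 10]]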
test_gelu_approx(ops, X): assert not ops.xp.isnan(Y).any() +@pytest.mark.parametrize("ops", ALL_OPS) +@settings(max_examples=MAX_EXAMPLES, deadline=None) +@given(X=strategies.arrays_BI()) +def test_dish(ops, X): + X = ops.asarray(X) + Y = ops.dish(X) + assert Y.shape == X.shape + assert not ops.xp.isnan(Y).any() + + @pytest.mark.parametrize("ops", ALL_OPS) @settings(max_examples=MAX_EXAMPLES, deadline=None) @given(X=strategies.arrays_BI()) @@ -1242,28 +1365,31 @@ def test_ngrams(): @pytest.mark.parametrize("dtype", ["float32", "float64"]) @pytest.mark.parametrize("torch_func", TORCH_FUNCS) @settings(max_examples=MAX_EXAMPLES, deadline=None) -@given(x=strategies.floats(min_value=-30, max_value=30)) -def test_compare_activations_to_torch(ops, dtype, x, torch_func): +@given( + x=strategies.floats(min_value=-30, max_value=30), + dY=strategies.floats(min_value=-1, max_value=1), +) +def test_compare_activations_to_torch(ops, dtype, x, dY, torch_func): import torch - def cast_torch(scalar: float): - return torch.tensor([scalar], requires_grad=True) - func_name, pytorch_func = torch_func forward = getattr(ops, func_name) backward = getattr(ops, "backprop_" + func_name) # The tolerance of isclose is set to 1e-06 instead of # the default 1e-08 due to the GELU x_thinc = ops.asarray([x], dtype=dtype) - x_torch = cast_torch(x) + x_torch = xp2torch(x_thinc, requires_grad=True) y = pytorch_func(x_torch) y_thinc = forward(x_thinc) y.backward() assert x_thinc.dtype == y_thinc.dtype - assert ops.xp.isclose(y_thinc, forward(x_thinc, inplace=True), atol=1e-06) - assert ops.xp.isclose(y_thinc, y.detach().numpy(), atol=1e-06) + assert y_thinc is not x_thinc + y_think_inplace = forward(x_thinc, inplace=True) + assert y_think_inplace is x_thinc + assert ops.xp.isclose(y_thinc, y_think_inplace, atol=1e-06) + assert ops.xp.isclose(y_thinc, y.detach(), atol=1e-05) x_thinc = ops.asarray([x], dtype=dtype) - dY_thinc = ops.asarray([1.0], dtype=dtype) + dY_thinc = ops.asarray([dY], dtype=dtype) dY_thinc_inplace = dY_thinc.copy() s = inspect.signature(backward) @@ -1272,11 +1398,13 @@ def cast_torch(scalar: float): if params == {"dY", "X", "Y"}: dx_thinc = backward(dY_thinc, Y=y_thinc, X=x_thinc) assert dx_thinc.dtype == x_thinc.dtype - assert ops.xp.isclose( - dx_thinc, - backward(dY=dY_thinc_inplace, Y=y_thinc, X=x_thinc, inplace=True), + assert dx_thinc is not dY_thinc + dx_thinc_inplace = backward( + dY=dY_thinc_inplace, Y=y_thinc, X=x_thinc, inplace=True ) - assert ops.xp.isclose(x_torch.grad.item(), float(dx_thinc), atol=1e-06) + assert dx_thinc_inplace is dY_thinc_inplace + assert ops.xp.isclose(dx_thinc, dx_thinc_inplace) + assert ops.xp.isclose(x_torch.grad.item() * dY, float(dx_thinc), atol=1e-06) elif params == {"Y", "dY"}: dx_thinc = backward(dY_thinc, Y=y_thinc) assert dx_thinc.dtype == x_thinc.dtype @@ -1284,7 +1412,7 @@ def cast_torch(scalar: float): dx_thinc, backward(dY=dY_thinc_inplace, Y=y_thinc, inplace=True), ) - assert ops.xp.isclose(x_torch.grad.item(), float(dx_thinc), atol=1e-06) + assert ops.xp.isclose(x_torch.grad.item() * dY, float(dx_thinc), atol=1e-06) elif params == {"dY", "X"}: dx_thinc = backward(dY_thinc, X=x_thinc) assert dx_thinc.dtype == x_thinc.dtype @@ -1292,7 +1420,7 @@ def cast_torch(scalar: float): dx_thinc, backward(dY=dY_thinc_inplace, X=x_thinc, inplace=True) ) assert ops.xp.isclose( - x_torch.grad.item(), float(backward(dY_thinc, X=x_thinc)), atol=1e-06 + x_torch.grad.item() * dY, float(backward(dY_thinc, X=x_thinc)), atol=1e-06 ) else: raise NotImplementedError( diff --git 
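# A sketch of the chain-rule identity the updated assertions above rely on:
# backprop_<activation>(dY) returns f'(x) * dY, while y.backward() on a
# single-element tensor leaves f'(x) in x.grad, so the PyTorch reference has
# to be scaled by dY before comparing. Illustrated with sigmoid; assumes
# PyTorch is installed.
import torch

x = torch.tensor([0.5], requires_grad=True)
y = torch.sigmoid(x)
y.backward()                                  # x.grad now holds f'(0.5)
dY = 0.25
s = torch.sigmoid(torch.tensor([0.5])).item()
assert abs(x.grad.item() * dY - s * (1.0 - s) * dY) < 1e-6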
a/thinc/tests/conftest.py b/thinc/tests/conftest.py index 239628b43..19b5137d3 100644 --- a/thinc/tests/conftest.py +++ b/thinc/tests/conftest.py @@ -1,8 +1,37 @@ import pytest +from hypothesis import settings + +# Functionally disable deadline settings for tests +# to prevent spurious test failures in CI builds. +settings.register_profile("no_deadlines", deadline=2 * 60 * 1000) # in ms +settings.load_profile("no_deadlines") + + +def pytest_sessionstart(session): + # If Tensorflow is installed, attempt to enable memory growth + # to prevent it from allocating all of the GPU's free memory + # to its internal memory pool(s). + try: + import tensorflow as tf + + physical_devices = tf.config.list_physical_devices("GPU") + for device in physical_devices: + try: + tf.config.experimental.set_memory_growth(device, True) + except: + # Invalid device or cannot modify virtual devices once initialized. + print(f"failed to enable Tensorflow memory growth on {device}") + except ImportError: + pass def pytest_addoption(parser): - parser.addoption("--slow", action="store_true", help="include slow tests") + try: + parser.addoption("--slow", action="store_true", help="include slow tests") + # Options are already added, e.g. if conftest is copied in a build pipeline + # and runs twice + except ValueError: + pass def pytest_runtest_setup(item): diff --git a/thinc/tests/layers/test_basic_tagger.py b/thinc/tests/layers/test_basic_tagger.py index e9075fec2..3046c1b04 100644 --- a/thinc/tests/layers/test_basic_tagger.py +++ b/thinc/tests/layers/test_basic_tagger.py @@ -8,13 +8,14 @@ def ancora(): pytest.importorskip("ml_datasets") import ml_datasets + return ml_datasets.ud_ancora_pos_tags() def create_embed_relu_relu_softmax(depth, width, vector_length): with Model.define_operators({">>": chain}): model = strings2arrays() >> with_array( - HashEmbed(width, vector_length) + HashEmbed(width, vector_length, column=0) >> expand_window(window_size=1) >> Relu(width, width * 3) >> Relu(width, width) diff --git a/thinc/tests/layers/test_combinators.py b/thinc/tests/layers/test_combinators.py index ed9a2992a..ea5583108 100644 --- a/thinc/tests/layers/test_combinators.py +++ b/thinc/tests/layers/test_combinators.py @@ -271,10 +271,7 @@ def test_concatenate(): def test_map_list(): nI = 4 nO = 9 - Xs = [ - numpy.zeros((6, nI), dtype="f"), - numpy.ones((3, nI), dtype="f") - ] + Xs = [numpy.zeros((6, nI), dtype="f"), numpy.ones((3, nI), dtype="f")] Y_shapes = [(x.shape[0], nO) for x in Xs] model = map_list(Linear()) model.initialize(X=Xs, Y=[numpy.zeros(shape, dtype="f") for shape in Y_shapes]) diff --git a/thinc/tests/layers/test_layers_api.py b/thinc/tests/layers/test_layers_api.py index e612cf318..c6e40c3fa 100644 --- a/thinc/tests/layers/test_layers_api.py +++ b/thinc/tests/layers/test_layers_api.py @@ -5,7 +5,7 @@ from thinc.backends import NumpyOps from thinc.util import data_validation, get_width from thinc.types import Ragged, Padded, Array2d, Floats2d, FloatsXd, Shape -from thinc.util import has_torch +from thinc.compat import has_torch import numpy import pytest @@ -57,6 +57,8 @@ def assert_data_match(Y, out_data): TEST_CASES_SUMMABLE = [ # Array to array + ("Dish.v1", {}, array2d, array2d), + ("Dish.v1", {"nO": 4, "nI": 4}, array2d, array2d), ("Dropout.v1", {}, array2d, array2d), ("LayerNorm.v1", {}, array2d, array2d), ("Linear.v1", {}, array2d, array2d), @@ -126,7 +128,8 @@ def assert_data_match(Y, out_data): # ("CauchySimilarity.v1", {}, (array2d, array2d), array1d), ("ParametricAttention.v1", {}, ragged, ragged), 
("SparseLinear.v1", {}, (numpy.asarray([1, 2, 3], dtype="uint64"), array1d, numpy.asarray([1, 1], dtype="i")), array2d), - ("remap_ids.v1", {"dtype": "f"}, ["a", 1, 5.0], array2dint) + ("remap_ids.v1", {"dtype": "f"}, ["a", 1, 5.0], array2dint), + ("remap_ids.v2", {"mapping_table": {}, "column": 1}, numpy.array([[1, 2, 3], [4, 5, 6]]).T, array2dint) # fmt: on ] diff --git a/thinc/tests/layers/test_linear.py b/thinc/tests/layers/test_linear.py index ef00d77b6..2362b556b 100644 --- a/thinc/tests/layers/test_linear.py +++ b/thinc/tests/layers/test_linear.py @@ -87,6 +87,7 @@ def test_predict_small(W_b_input): @given(arrays_OI_O_BI(max_batch=20, max_out=30, max_in=30)) +@settings(deadline=None) def test_predict_extensive(W_b_input): W, b, input_ = W_b_input nr_out, nr_in = W.shape diff --git a/thinc/tests/layers/test_lstm.py b/thinc/tests/layers/test_lstm.py index 850f569b8..208ffb58b 100644 --- a/thinc/tests/layers/test_lstm.py +++ b/thinc/tests/layers/test_lstm.py @@ -2,7 +2,7 @@ import timeit from thinc.api import NumpyOps, LSTM, PyTorchLSTM, with_padded, fix_random_seed from thinc.api import Ops -from thinc.util import has_torch +from thinc.compat import has_torch import pytest diff --git a/thinc/tests/layers/test_mnist.py b/thinc/tests/layers/test_mnist.py index 0ed0dfd64..321de3a0f 100644 --- a/thinc/tests/layers/test_mnist.py +++ b/thinc/tests/layers/test_mnist.py @@ -1,13 +1,15 @@ import pytest from thinc.api import Relu, Softmax, chain, clone, Adam from thinc.api import PyTorchWrapper, TensorFlowWrapper -from thinc.util import has_torch, has_tensorflow +from thinc.api import get_current_ops +from thinc.compat import has_torch, has_tensorflow @pytest.fixture(scope="module") def mnist(limit=5000): pytest.importorskip("ml_datasets") import ml_datasets + (train_X, train_Y), (dev_X, dev_Y) = ml_datasets.mnist() return (train_X[:limit], train_Y[:limit]), (dev_X[:limit], dev_Y[:limit]) @@ -79,9 +81,14 @@ def test_small_end_to_end(width, nb_epoch, min_score, create_model, mnist): optimizer = Adam(0.001) losses = [] scores = [] + ops = get_current_ops() + for i in range(nb_epoch): for X, Y in model.ops.multibatch(batch_size, train_X, train_Y, shuffle=True): Yh, backprop = model.begin_update(X) + # Ensure that the tensor is type-compatible with the current backend. 
+ Yh = ops.asarray(Yh) + backprop(Yh - Y) model.finish_update(optimizer) losses.append(((Yh - Y) ** 2).sum()) @@ -89,6 +96,8 @@ def test_small_end_to_end(width, nb_epoch, min_score, create_model, mnist): total = 0 for X, Y in model.ops.multibatch(batch_size, dev_X, dev_Y): Yh = model.predict(X) + Yh = ops.asarray(Yh) + correct += (Yh.argmax(axis=1) == Y.argmax(axis=1)).sum() total += Yh.shape[0] score = correct / total diff --git a/thinc/tests/layers/test_mxnet_wrapper.py b/thinc/tests/layers/test_mxnet_wrapper.py index 438b28f33..b954a8ec5 100644 --- a/thinc/tests/layers/test_mxnet_wrapper.py +++ b/thinc/tests/layers/test_mxnet_wrapper.py @@ -5,7 +5,8 @@ from thinc.api import Adam, ArgsKwargs, Model, Ops, MXNetWrapper from thinc.api import get_current_ops, mxnet2xp, xp2mxnet from thinc.types import Array2d, Array1d, IntsXd -from thinc.util import has_cupy, has_mxnet, to_categorical +from thinc.compat import has_cupy_gpu, has_mxnet +from thinc.util import to_categorical from ..util import check_input_converters, make_tempdir @@ -33,7 +34,7 @@ def answer() -> int: @pytest.fixture def X(input_size: int) -> Array2d: ops: Ops = get_current_ops() - return ops.alloc(shape=(1, input_size)) + return cast(Array2d, ops.alloc(shape=(1, input_size))) @pytest.fixture @@ -157,7 +158,8 @@ def test_mxnet_wrapper_to_cpu(mx_model, X: Array2d): model.to_cpu() -@pytest.mark.skipif(not has_mxnet or not has_cupy, reason="needs MXNet") +@pytest.mark.skipif(not has_mxnet, reason="needs MXNet") +@pytest.mark.skipif(not has_cupy_gpu, reason="needs GPU/cupy") def test_mxnet_wrapper_to_gpu(model: Model[Array2d, Array2d], X: Array2d): model.predict(X) model.to_gpu(0) diff --git a/thinc/tests/layers/test_pytorch_wrapper.py b/thinc/tests/layers/test_pytorch_wrapper.py index ce3b6ae8d..d2eeaeb97 100644 --- a/thinc/tests/layers/test_pytorch_wrapper.py +++ b/thinc/tests/layers/test_pytorch_wrapper.py @@ -1,20 +1,37 @@ from thinc.api import Linear, SGD, PyTorchWrapper, PyTorchWrapper_v2 from thinc.api import xp2torch, torch2xp, ArgsKwargs, use_ops from thinc.api import chain, get_current_ops, Relu +from thinc.api import CupyOps, MPSOps, NumpyOps from thinc.backends import context_pools from thinc.shims.pytorch_grad_scaler import PyTorchGradScaler -from thinc.util import has_torch, has_torch_amp, has_torch_gpu +from thinc.compat import has_torch, has_torch_amp +from thinc.compat import has_cupy_gpu, has_torch_mps_gpu import numpy import pytest +from thinc.util import get_torch_default_device from ..util import make_tempdir, check_input_converters +XP_OPS = [NumpyOps()] +if has_cupy_gpu: + XP_OPS.append(CupyOps()) +if has_torch_mps_gpu: + XP_OPS.append(MPSOps()) + + if has_torch_amp: TORCH_MIXED_PRECISION = [False, True] else: TORCH_MIXED_PRECISION = [False] +XP_OPS_MIXED = [ + (ops, mixed) + for ops in XP_OPS + for mixed in TORCH_MIXED_PRECISION + if not mixed or isinstance(ops, CupyOps) +] + def check_learns_zero_output(model, sgd, X, Y): """Check we can learn to output a zero vector""" @@ -63,22 +80,25 @@ def test_pytorch_wrapper(nN, nI, nO): assert isinstance(model.predict(X), numpy.ndarray) -@pytest.mark.skipif(not has_torch_gpu, reason="needs PyTorch with CUDA-capable GPU") +@pytest.mark.skipif(not has_torch, reason="needs PyTorch") +@pytest.mark.parametrize("ops_mixed", XP_OPS_MIXED) @pytest.mark.parametrize("nN,nI,nO", [(2, 3, 4)]) -@pytest.mark.parametrize("mixed_precision", TORCH_MIXED_PRECISION) -def test_pytorch_wrapper_thinc_input(nN, nI, nO, mixed_precision): +def test_pytorch_wrapper_thinc_input(ops_mixed, nN, 
nI, nO): import torch.nn - with use_ops("cupy"): + ops, mixed_precision = ops_mixed + + with use_ops(ops.name): ops = get_current_ops() pytorch_layer = torch.nn.Linear(nO, nO) # Initialize with large weights to trigger overflow of FP16 in # mixed-precision training. torch.nn.init.uniform_(pytorch_layer.weight, 9.0, 11.0) + device = get_torch_default_device() model = chain( Relu(), PyTorchWrapper_v2( - pytorch_layer.cuda(), + pytorch_layer.to(device), mixed_precision=mixed_precision, grad_scaler=PyTorchGradScaler( enabled=mixed_precision, init_scale=2.0**16 @@ -86,7 +106,8 @@ def test_pytorch_wrapper_thinc_input(nN, nI, nO, mixed_precision): ).initialize(), ) # pytorch allocator is set in PyTorchShim - assert "pytorch" in context_pools.get() + if isinstance(ops, CupyOps): + assert "pytorch" in context_pools.get() sgd = SGD(0.001) X = ops.xp.zeros((nN, nI), dtype="f") X += ops.xp.random.uniform(size=X.size).reshape(X.shape) diff --git a/thinc/tests/layers/test_reduce.py b/thinc/tests/layers/test_reduce.py index ba829f779..d26065c4a 100644 --- a/thinc/tests/layers/test_reduce.py +++ b/thinc/tests/layers/test_reduce.py @@ -92,6 +92,7 @@ def test_reduce_mean(Xs): dX = backprop(Y) assert dX.dataXd.shape == X.dataXd.shape + def test_reduce_sum(Xs): model = reduce_sum() lengths = model.ops.asarray([x.shape[0] for x in Xs], dtype="i") @@ -107,6 +108,7 @@ def test_reduce_sum(Xs): dX = backprop(Y) assert dX.dataXd.shape == X.dataXd.shape + def test_size_mismatch(Xs): for reduce in [reduce_first, reduce_last, reduce_max, reduce_mean, reduce_sum]: model = reduce() diff --git a/thinc/tests/layers/test_tensorflow_wrapper.py b/thinc/tests/layers/test_tensorflow_wrapper.py index 1c10b8242..c1b85da3b 100644 --- a/thinc/tests/layers/test_tensorflow_wrapper.py +++ b/thinc/tests/layers/test_tensorflow_wrapper.py @@ -2,7 +2,8 @@ import pytest from thinc.api import Adam, ArgsKwargs, Linear, Model, TensorFlowWrapper from thinc.api import get_current_ops, keras_subclass, tensorflow2xp, xp2tensorflow -from thinc.util import gpu_is_available, has_tensorflow, to_categorical +from thinc.util import to_categorical +from thinc.compat import has_cupy_gpu, has_tensorflow from ..util import check_input_converters, make_tempdir @@ -358,7 +359,7 @@ def test_tensorflow_wrapper_to_cpu(tf_model): @pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow") -@pytest.mark.skipif(not gpu_is_available(), reason="needs GPU/cupy") +@pytest.mark.skipif(not has_cupy_gpu, reason="needs GPU/cupy") def test_tensorflow_wrapper_to_gpu(model, X): model.to_gpu(0) diff --git a/thinc/tests/layers/test_uniqued.py b/thinc/tests/layers/test_uniqued.py index 41d98ca16..9cb207ca5 100644 --- a/thinc/tests/layers/test_uniqued.py +++ b/thinc/tests/layers/test_uniqued.py @@ -3,7 +3,7 @@ from thinc.layers import Embed from thinc.layers.uniqued import uniqued from numpy.testing import assert_allclose -from hypothesis import given +from hypothesis import given, settings from hypothesis.strategies import integers, lists, composite ROWS = 10 @@ -46,6 +46,7 @@ def test_uniqued_calls_init(): @given(X=lists_of_integers(lo=0, hi=ROWS - 1)) +@settings(deadline=None) def test_uniqued_doesnt_change_result(model, X): umodel = uniqued(model, column=model.attrs["column"]).initialize() Y, bp_Y = model(X, is_train=True) diff --git a/thinc/tests/layers/test_with_transforms.py b/thinc/tests/layers/test_with_transforms.py index a01e20567..c23db1463 100644 --- a/thinc/tests/layers/test_with_transforms.py +++ b/thinc/tests/layers/test_with_transforms.py @@ -26,8 +26,8 
@@ def list_input(shapes): for i, x in enumerate(data): # Give values that make it easy to see where rows or columns mismatch. x += i * 100 - x += numpy.arange(x.shape[0]).reshape((-1, 1)) * 10 - x += numpy.arange(x.shape[1]).reshape((1, -1)) + x += numpy.arange(x.shape[0]).reshape((-1, 1)) * 10 + x += numpy.arange(x.shape[1]).reshape((1, -1)) return data @@ -68,8 +68,10 @@ def noop_models(): with_array(noop()), with_array2d(noop()), with_list(noop()), - with_ragged(noop()) + with_ragged(noop()), ] + + # As an example operation, lets just trim the last dimension. That # should catch stuff that confuses the input and output. @@ -180,14 +182,14 @@ def test_noop_transforms(noop_models, ragged_input, padded_input, list_input): d_ragged = Ragged(ragged_input.data + 1, ragged_input.lengths) d_padded = padded_input.copy() d_padded.data += 1 - d_list = [dx+1 for dx in list_input] + d_list = [dx + 1 for dx in list_input] for model in noop_models: print(model.name) check_transform_doesnt_change_noop_values(model, padded_input, d_padded) check_transform_doesnt_change_noop_values(model, list_input, d_list) check_transform_doesnt_change_noop_values(model, ragged_input, d_ragged) - + def test_with_array_initialize(ragged_input, padded_input, list_input, array_input): for inputs in (ragged_input, padded_input, list_input, array_input): check_initialize(get_array_model(), inputs) diff --git a/thinc/tests/model/test_model.py b/thinc/tests/model/test_model.py index c552e3e4c..733b3329f 100644 --- a/thinc/tests/model/test_model.py +++ b/thinc/tests/model/test_model.py @@ -4,9 +4,9 @@ import time from thinc.api import Adam, CupyOps, Dropout, Linear, Model, Relu from thinc.api import Shim, Softmax, chain, change_attr_values -from thinc.api import concatenate, has_cupy, set_dropout_rate +from thinc.api import concatenate, set_dropout_rate from thinc.api import use_ops, with_debug, wrap_model_recursive -from thinc.util import gpu_is_available +from thinc.compat import has_cupy_gpu import numpy from ..util import make_tempdir @@ -349,10 +349,10 @@ def test_all_operators(op): with pytest.raises(TypeError): value = m1 % m2 if op == "**": - value = m1 ** m2 + value = m1**m2 else: with pytest.raises(TypeError): - value = m1 ** m2 + value = m1**m2 if op == "<<": value = m1 << m2 else: @@ -404,15 +404,12 @@ def get_model_id(id_list, index): assert len(list_of_ids) == len(list(set(list_of_ids))) +@pytest.mark.skipif(not has_cupy_gpu, reason="needs CuPy GPU") def test_model_gpu(): pytest.importorskip("ml_datasets") import ml_datasets - ops = "cpu" - if has_cupy and gpu_is_available(): - ops = "cupy" - - with use_ops(ops): + with use_ops("cupy"): n_hidden = 32 dropout = 0.2 (train_X, train_Y), (dev_X, dev_Y) = ml_datasets.mnist() @@ -614,3 +611,44 @@ def test_walk_bfs_post_order_fails(): relu = Relu(5) with pytest.raises(ValueError, match="Invalid order"): relu.walk(order="dfs_post_order") + + +def test_model_copy_with_loop(): + class MyShim(Shim): + name = "testshim" + + def to_bytes(self): + return test_replace_node_with_indirect_node_ref + + def from_bytes(self, bytes): + pass + + model_a = create_model("a") + working_shim = MyShim(None) + layer = Model( + "test", + lambda X: (X, lambda dY: dY), + dims={"nI": 5, "nO": 5}, + params={"W": numpy.zeros((10,)), "b": None}, + refs={"a": model_a, "b": None}, + attrs={"foo": "bar"}, + shims=[working_shim], + layers=[model_a, model_a], + ) + layer2 = Model( + "test2", + lambda X: (X, lambda dY: dY), + dims={"nI": 5, "nO": 5}, + params={"W": numpy.zeros((10,)), "b": None}, + 
refs={"a": model_a, "b": None}, + attrs={"foo": "bar"}, + shims=[working_shim], + layers=[model_a, model_a], + ) + relu = Relu(5) + model = chain(layer, relu, layer, layer2) + model2 = model.copy() + model.from_dict(model2.to_dict()) + assert model2.name == "test>>relu>>test>>test2" + assert model2.layers[0] == model2.layers[2] + assert id(model2.layers[0].shims[0]) == id(model2.layers[3].shims[0]) diff --git a/thinc/tests/mypy/test_mypy.py b/thinc/tests/mypy/test_mypy.py index 287043578..e03d1c874 100644 --- a/thinc/tests/mypy/test_mypy.py +++ b/thinc/tests/mypy/test_mypy.py @@ -23,6 +23,7 @@ def test_mypy_results( ): pytest.importorskip("mypy") from mypy import api as mypy_api + os.chdir(tmpdir) root_dir = Path(__file__).parent thinc_root_dir = Path(__file__).parent.parent.parent.parent diff --git a/thinc/tests/regression/issue519/program.py b/thinc/tests/regression/issue519/program.py index 2ad28d88b..b3e6dc9ba 100644 --- a/thinc/tests/regression/issue519/program.py +++ b/thinc/tests/regression/issue519/program.py @@ -5,16 +5,12 @@ n_hidden = 32 dropout = 0.2 -model1 = chain( - Relu(nO=n_hidden, dropout=dropout), - Relu(nO=n_hidden, dropout=dropout), - Softmax() +model1: Model[Floats2d, Floats2d] = chain( + Relu(nO=n_hidden, dropout=dropout), Relu(nO=n_hidden, dropout=dropout), Softmax() ) -model2 = chain( - Relu(nO=n_hidden, dropout=dropout), - Relu(nO=n_hidden, dropout=dropout), - Softmax() +model2: Model[Floats2d, Floats2d] = chain( + Relu(nO=n_hidden, dropout=dropout), Relu(nO=n_hidden, dropout=dropout), Softmax() ) model3: Model[Floats2d, Floats2d] = concatenate(*[model1, model2]) diff --git a/thinc/tests/regression/issue519/test_issue519.py b/thinc/tests/regression/issue519/test_issue519.py index 4a26f80e0..02601f0d7 100644 --- a/thinc/tests/regression/issue519/test_issue519.py +++ b/thinc/tests/regression/issue519/test_issue519.py @@ -1,4 +1,6 @@ import subprocess +import sys + try: import importlib.resources as importlib_resources except ImportError: @@ -16,10 +18,12 @@ def test_issue519(): This test can take up to 45 seconds, and is thus marked as slow. 
""" # Determine the name of the parent module (which contains the test program) - parent_module_name = __name__[:__name__.rfind(".")] + parent_module_name = __name__[: __name__.rfind(".")] # Load test program that calls a Thinc API with variadic arguments program_text = importlib_resources.read_text(parent_module_name, "program.py") # Ask Mypy to type-check the loaded program text - subprocess.run(["mypy", "--command", program_text], check=True) + subprocess.run( + [sys.executable, "-m", "mypy", "--command", program_text], check=True + ) diff --git a/thinc/tests/regression/test_issue564.py b/thinc/tests/regression/test_issue564.py index 895d25cfa..94ecc6e63 100644 --- a/thinc/tests/regression/test_issue564.py +++ b/thinc/tests/regression/test_issue564.py @@ -1,11 +1,11 @@ import pytest from thinc.api import CupyOps -from thinc.util import has_torch, has_torch_gpu +from thinc.compat import has_torch, has_torch_cuda_gpu @pytest.mark.skipif(not has_torch, reason="needs PyTorch") -@pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU") +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs a GPU") def test_issue564(): import torch diff --git a/thinc/tests/shims/test_pytorch_grad_scaler.py b/thinc/tests/shims/test_pytorch_grad_scaler.py index 26eab9291..2ab0fa738 100644 --- a/thinc/tests/shims/test_pytorch_grad_scaler.py +++ b/thinc/tests/shims/test_pytorch_grad_scaler.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis.strategies import lists, one_of, tuples -from thinc.util import has_torch, has_torch_amp, has_torch_gpu +from thinc.compat import has_torch, has_torch_amp, has_torch_cuda_gpu, torch from thinc.util import is_torch_array from thinc.api import PyTorchGradScaler @@ -10,19 +10,11 @@ def tensors(): - # This function is not used without Torch + CUDA, - # but we have to do some wrapping to avoid import - # failures. 
- try: - import torch - - return ndarrays().map(lambda a: torch.tensor(a).cuda()) - except ImportError: - pass + return ndarrays().map(lambda a: torch.tensor(a).cuda()) @pytest.mark.skipif(not has_torch, reason="needs PyTorch") -@pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU") +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs a GPU") @pytest.mark.skipif( not has_torch_amp, reason="requires PyTorch with mixed-precision support" ) @@ -45,7 +37,7 @@ def test_scale_random_inputs(X): @pytest.mark.skipif(not has_torch, reason="needs PyTorch") -@pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU") +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs a GPU") @pytest.mark.skipif( not has_torch_amp, reason="requires PyTorch with mixed-precision support" ) @@ -97,6 +89,7 @@ def test_grad_scaler(): ) def test_raises_on_old_pytorch(): import torch + scaler = PyTorchGradScaler(enabled=True) with pytest.raises(ValueError, match=r"not supported.*1.9.0"): scaler.scale([torch.tensor([1.0], device="cpu")]) diff --git a/thinc/tests/test_config.py b/thinc/tests/test_config.py index ddd05ca96..0dceadfc4 100644 --- a/thinc/tests/test_config.py +++ b/thinc/tests/test_config.py @@ -135,361 +135,6 @@ def catsie_v2(evil: StrictBool, cute: bool = True, cute_level: int = 1) -> str: worst_catsie = {"@cats": "catsie.v1", "evil": True, "cute": False} -def test_validate_simple_config(): - simple_config = {"hello": 1, "world": 2} - f, _, v = my_registry._fill(simple_config, HelloIntsSchema) - assert f == simple_config - assert v == simple_config - - -def test_invalidate_simple_config(): - invalid_config = {"hello": 1, "world": "hi!"} - with pytest.raises(ConfigValidationError) as exc_info: - my_registry._fill(invalid_config, HelloIntsSchema) - error = exc_info.value - assert len(error.errors) == 1 - assert "type_error.integer" in error.error_types - - -def test_invalidate_extra_args(): - invalid_config = {"hello": 1, "world": 2, "extra": 3} - with pytest.raises(ConfigValidationError): - my_registry._fill(invalid_config, HelloIntsSchema) - - -def test_fill_defaults_simple_config(): - valid_config = {"required": 1} - filled, _, v = my_registry._fill(valid_config, DefaultsSchema) - assert filled["required"] == 1 - assert filled["optional"] == "default value" - invalid_config = {"optional": "some value"} - with pytest.raises(ConfigValidationError): - my_registry._fill(invalid_config, DefaultsSchema) - - -def test_fill_recursive_config(): - valid_config = {"outer_req": 1, "level2_req": {"hello": 4, "world": 7}} - filled, _, validation = my_registry._fill(valid_config, ComplexSchema) - assert filled["outer_req"] == 1 - assert filled["outer_opt"] == "default value" - assert filled["level2_req"]["hello"] == 4 - assert filled["level2_req"]["world"] == 7 - assert filled["level2_opt"]["required"] == 1 - assert filled["level2_opt"]["optional"] == "default value" - - -def test_is_promise(): - assert my_registry.is_promise(good_catsie) - assert not my_registry.is_promise({"hello": "world"}) - assert not my_registry.is_promise(1) - invalid = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"} - assert my_registry.is_promise(invalid) - - -def test_get_constructor(): - my_registry.get_constructor(good_catsie) == ("cats", "catsie.v1") - - -def test_parse_args(): - args, kwargs = my_registry.parse_args(bad_catsie) - assert args == [] - assert kwargs == {"evil": True, "cute": True} - - -def test_make_promise_schema(): - schema = my_registry.make_promise_schema(good_catsie) - assert "evil" in 
schema.__fields__ - assert "cute" in schema.__fields__ - - -def test_validate_promise(): - config = {"required": 1, "optional": good_catsie} - filled, _, validated = my_registry._fill(config, DefaultsSchema) - assert filled == config - assert validated == {"required": 1, "optional": "meow"} - - -def test_fill_validate_promise(): - config = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}} - filled, _, validated = my_registry._fill(config, DefaultsSchema) - assert filled["optional"]["cute"] is True - - -def test_fill_invalidate_promise(): - config = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}} - with pytest.raises(ConfigValidationError): - my_registry._fill(config, HelloIntsSchema) - config["optional"]["whiskers"] = True - with pytest.raises(ConfigValidationError): - my_registry._fill(config, DefaultsSchema) - - -def test_create_registry(): - with pytest.raises(ValueError): - my_registry.create("cats") - my_registry.create("dogs") - assert hasattr(my_registry, "dogs") - assert len(my_registry.dogs.get_all()) == 0 - my_registry.dogs.register("good_boy.v1", func=lambda x: x) - assert len(my_registry.dogs.get_all()) == 1 - with pytest.raises(ValueError): - my_registry.create("dogs") - - -def test_registry_methods(): - with pytest.raises(ValueError): - my_registry.get("dfkoofkds", "catsie.v1") - my_registry.cats.register("catsie.v123")(None) - with pytest.raises(ValueError): - my_registry.get("cats", "catsie.v123") - - -def test_resolve_no_schema(): - config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}} - result = my_registry.resolve({"cfg": config})["cfg"] - assert result["one"] == 1 - assert result["two"] == {"three": "scratch!"} - with pytest.raises(ConfigValidationError): - config = {"two": {"three": {"@cats": "catsie.v1", "evil": "true"}}} - my_registry.resolve(config) - - -def test_resolve_schema(): - class TestBaseSubSchema(BaseModel): - three: str - - class TestBaseSchema(BaseModel): - one: PositiveInt - two: TestBaseSubSchema - - class Config: - extra = "forbid" - - class TestSchema(BaseModel): - cfg: TestBaseSchema - - config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}} - my_registry.resolve({"cfg": config}, schema=TestSchema) - config = {"one": -1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}} - with pytest.raises(ConfigValidationError): - # "one" is not a positive int - my_registry.resolve({"cfg": config}, schema=TestSchema) - config = {"one": 1, "two": {"four": {"@cats": "catsie.v1", "evil": True}}} - with pytest.raises(ConfigValidationError): - # "three" is required in subschema - my_registry.resolve({"cfg": config}, schema=TestSchema) - - -def test_resolve_schema_coerced(): - class TestBaseSchema(BaseModel): - test1: str - test2: bool - test3: float - - class TestSchema(BaseModel): - cfg: TestBaseSchema - - config = {"test1": 123, "test2": 1, "test3": 5} - filled = my_registry.fill({"cfg": config}, schema=TestSchema) - result = my_registry.resolve({"cfg": config}, schema=TestSchema) - assert result["cfg"] == {"test1": "123", "test2": True, "test3": 5.0} - # This only affects the resolved config, not the filled config - assert filled["cfg"] == config - - -def test_read_config(): - byte_string = EXAMPLE_CONFIG.encode("utf8") - cfg = Config().from_bytes(byte_string) - - assert cfg["optimizer"]["beta1"] == 0.9 - assert cfg["optimizer"]["learn_rate"]["initial_rate"] == 0.1 - assert cfg["pipeline"]["parser"]["factory"] == "parser" - assert 
cfg["pipeline"]["parser"]["model"]["tok2vec"]["width"] == 128 - - -def test_optimizer_config(): - cfg = Config().from_str(OPTIMIZER_CFG) - optimizer = my_registry.resolve(cfg, validate=True)["optimizer"] - assert optimizer.b1 == 0.9 - - -def test_config_to_str(): - cfg = Config().from_str(OPTIMIZER_CFG) - assert cfg.to_str().strip() == OPTIMIZER_CFG.strip() - cfg = Config({"optimizer": {"foo": "bar"}}).from_str(OPTIMIZER_CFG) - assert cfg.to_str().strip() == OPTIMIZER_CFG.strip() - - -def test_config_to_str_creates_intermediate_blocks(): - cfg = Config({"optimizer": {"foo": {"bar": 1}}}) - assert ( - cfg.to_str().strip() - == """ -[optimizer] - -[optimizer.foo] -bar = 1 - """.strip() - ) - - -def test_config_roundtrip_bytes(): - cfg = Config().from_str(OPTIMIZER_CFG) - cfg_bytes = cfg.to_bytes() - new_cfg = Config().from_bytes(cfg_bytes) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() - - -def test_config_roundtrip_disk(): - cfg = Config().from_str(OPTIMIZER_CFG) - with make_tempdir() as path: - cfg_path = path / "config.cfg" - cfg.to_disk(cfg_path) - new_cfg = Config().from_disk(cfg_path) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() - - -def test_config_roundtrip_disk_respects_path_subclasses(pathy_fixture): - cfg = Config().from_str(OPTIMIZER_CFG) - cfg_path = pathy_fixture / "config.cfg" - cfg.to_disk(cfg_path) - new_cfg = Config().from_disk(cfg_path) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() - - -def test_config_to_str_invalid_defaults(): - """Test that an error is raised if a config contains top-level keys without - a section that would otherwise be interpreted as [DEFAULT] (which causes - the values to be included in *all* other sections). - """ - cfg = {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}} - with pytest.raises(ConfigValidationError): - Config(cfg).to_str() - config_str = "[DEFAULT]\none = 1" - with pytest.raises(ConfigValidationError): - Config().from_str(config_str) - - -def test_validation_custom_types(): - def complex_args( - rate: StrictFloat, - steps: PositiveInt = 10, # type: ignore - log_level: constr(regex="(DEBUG|INFO|WARNING|ERROR)") = "ERROR", - ): - return None - - my_registry.create("complex") - my_registry.complex("complex.v1")(complex_args) - cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"} - my_registry.resolve({"config": cfg}) - cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": -1, "log_level": "INFO"} - with pytest.raises(ConfigValidationError): - # steps is not a positive int - my_registry.resolve({"config": cfg}) - cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "none"} - with pytest.raises(ConfigValidationError): - # log_level is not a string matching the regex - my_registry.resolve({"config": cfg}) - cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"} - with pytest.raises(ConfigValidationError): - # top-level object is promise - my_registry.resolve(cfg) - with pytest.raises(ConfigValidationError): - # top-level object is promise - my_registry.fill(cfg) - cfg = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"} - with pytest.raises(ConfigValidationError): - # two constructors - my_registry.resolve({"config": cfg}) - - -def test_validation_no_validate(): - config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": "false"}}} - result = my_registry.resolve({"cfg": config}, validate=False) - filled = my_registry.fill({"cfg": config}, validate=False) - assert result["cfg"]["one"] == 1 - assert 
result["cfg"]["two"] == {"three": "scratch!"} - assert filled["cfg"]["two"]["three"]["evil"] == "false" - assert filled["cfg"]["two"]["three"]["cute"] is True - - -def test_validation_fill_defaults(): - config = {"cfg": {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}} - result = my_registry.fill(config, validate=False) - assert len(result["cfg"]["two"]) == 3 - with pytest.raises(ConfigValidationError): - # Required arg "evil" is not defined - my_registry.fill(config) - config = {"cfg": {"one": 1, "two": {"@cats": "catsie.v2", "evil": False}}} - # Fill in with new defaults - result = my_registry.fill(config) - assert len(result["cfg"]["two"]) == 4 - assert result["cfg"]["two"]["evil"] is False - assert result["cfg"]["two"]["cute"] is True - assert result["cfg"]["two"]["cute_level"] == 1 - - -def test_make_config_positional_args(): - @my_registry.cats("catsie.v567") - def catsie_567(*args: Optional[str], foo: str = "bar"): - assert args[0] == "^_^" - assert args[1] == "^(*.*)^" - assert foo == "baz" - return args[0] - - args = ["^_^", "^(*.*)^"] - cfg = {"config": {"@cats": "catsie.v567", "foo": "baz", "*": args}} - assert my_registry.resolve(cfg)["config"] == "^_^" - - -def test_make_config_positional_args_complex(): - @my_registry.cats("catsie.v890") - def catsie_890(*args: Optional[Union[StrictBool, PositiveInt]]): - assert args[0] == 123 - return args[0] - - cfg = {"config": {"@cats": "catsie.v890", "*": [123, True, 1, False]}} - assert my_registry.resolve(cfg)["config"] == 123 - cfg = {"config": {"@cats": "catsie.v890", "*": [123, "True"]}} - with pytest.raises(ConfigValidationError): - # "True" is not a valid boolean or positive int - my_registry.resolve(cfg) - - -def test_positional_args_to_from_string(): - cfg = """[a]\nb = 1\n* = ["foo","bar"]""" - assert Config().from_str(cfg).to_str() == cfg - cfg = """[a]\nb = 1\n\n[a.*.bar]\ntest = 2\n\n[a.*.foo]\ntest = 1""" - assert Config().from_str(cfg).to_str() == cfg - - @my_registry.cats("catsie.v666") - def catsie_666(*args, meow=False): - return args - - cfg = """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]""" - filled = my_registry.fill(Config().from_str(cfg)).to_str() - assert filled == """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]\nmeow = false""" - resolved = my_registry.resolve(Config().from_str(cfg)) - assert resolved == {"a": ("foo", "bar")} - cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\nx = 1""" - filled = my_registry.fill(Config().from_str(cfg)).to_str() - assert filled == """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\nx = 1""" - resolved = my_registry.resolve(Config().from_str(cfg)) - assert resolved == {"a": ({"x": 1},)} - - @my_registry.cats("catsie.v777") - def catsie_777(y: int = 1): - return "meow" * y - - cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777\"""" - filled = my_registry.fill(Config().from_str(cfg)).to_str() - expected = """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 1""" - assert filled == expected - cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 3""" - result = my_registry.resolve(Config().from_str(cfg)) - assert result == {"a": ("meowmeowmeow",)} - - def test_make_config_positional_args_dicts(): cfg = { "hyper_params": {"n_hidden": 512, "dropout": 0.2, "learn_rate": 0.001}, @@ -511,51 +156,6 @@ def test_make_config_positional_args_dicts(): model.finish_update(resolved["optimizer"]) -def test_validation_generators_iterable(): - @my_registry.optimizers("test_optimizer.v1") - def 
test_optimizer_v1(rate: float) -> None: - return None - - @my_registry.schedules("test_schedule.v1") - def test_schedule_v1(some_value: float = 1.0) -> Iterable[float]: - while True: - yield some_value - - config = {"optimizer": {"@optimizers": "test_optimizer.v1", "rate": 0.1}} - my_registry.resolve(config) - - -def test_validation_unset_type_hints(): - """Test that unset type hints are handled correctly (and treated as Any).""" - - @my_registry.optimizers("test_optimizer.v2") - def test_optimizer_v2(rate, steps: int = 10) -> None: - return None - - config = {"test": {"@optimizers": "test_optimizer.v2", "rate": 0.1, "steps": 20}} - my_registry.resolve(config) - - -def test_validation_bad_function(): - @my_registry.optimizers("bad.v1") - def bad() -> None: - raise ValueError("This is an error in the function") - return None - - @my_registry.optimizers("good.v1") - def good() -> None: - return None - - # Bad function - config = {"test": {"@optimizers": "bad.v1"}} - with pytest.raises(ValueError): - my_registry.resolve(config) - # Bad function call - config = {"test": {"@optimizers": "good.v1", "invalid_arg": 1}} - with pytest.raises(ConfigValidationError): - my_registry.resolve(config) - - def test_objects_from_config(): config = { "optimizer": { @@ -583,93 +183,6 @@ def decaying(base_rate: float, repeat: int) -> List[float]: assert optimizer.learn_rate == 0.001 -def test_partials_from_config(): - """Test that functions registered with partial applications are handled - correctly (e.g. initializers).""" - name = "uniform_init.v1" - cfg = {"test": {"@initializers": name, "lo": -0.2}} - func = my_registry.resolve(cfg)["test"] - assert hasattr(func, "__call__") - # The partial will still have lo as an arg, just with default - assert len(inspect.signature(func).parameters) == 4 - # Make sure returned partial function has correct value set - assert inspect.signature(func).parameters["lo"].default == -0.2 - # Actually call the function and verify - func(NumpyOps(), (2, 3)) - # Make sure validation still works - bad_cfg = {"test": {"@initializers": name, "lo": [0.5]}} - with pytest.raises(ConfigValidationError): - my_registry.resolve(bad_cfg) - bad_cfg = {"test": {"@initializers": name, "lo": -0.2, "other": 10}} - with pytest.raises(ConfigValidationError): - my_registry.resolve(bad_cfg) - - -def test_partials_from_config_nested(): - """Test that partial functions are passed correctly to other registered - functions that consume them (e.g. 
initializers -> layers).""" - - def test_initializer(a: int, b: int = 1) -> int: - return a * b - - @my_registry.initializers("test_initializer.v1") - def configure_test_initializer(b: int = 1) -> Callable[[int], int]: - return partial(test_initializer, b=b) - - @my_registry.layers("test_layer.v1") - def test_layer(init: Callable[[int], int], c: int = 1) -> Callable[[int], int]: - return lambda x: x + init(c) - - cfg = { - "@layers": "test_layer.v1", - "c": 5, - "init": {"@initializers": "test_initializer.v1", "b": 10}, - } - func = my_registry.resolve({"test": cfg})["test"] - assert func(1) == 51 - assert func(100) == 150 - - -def test_validate_generator(): - """Test that generator replacement for validation in config doesn't - actually replace the returned value.""" - - @my_registry.schedules("test_schedule.v2") - def test_schedule(): - while True: - yield 10 - - cfg = {"@schedules": "test_schedule.v2"} - result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) - - @my_registry.optimizers("test_optimizer.v2") - def test_optimizer2(rate: Generator) -> Generator: - return rate - - cfg = { - "@optimizers": "test_optimizer.v2", - "rate": {"@schedules": "test_schedule.v2"}, - } - result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) - - @my_registry.optimizers("test_optimizer.v3") - def test_optimizer3(schedules: Dict[str, Generator]) -> Generator: - return schedules["rate"] - - cfg = { - "@optimizers": "test_optimizer.v3", - "schedules": {"rate": {"@schedules": "test_schedule.v2"}}, - } - result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) - - @my_registry.optimizers("test_optimizer.v4") - def test_optimizer4(*schedules: Generator) -> Generator: - return schedules[0] - - def test_handle_generic_model_type(): """Test that validation can handle checks against arbitrary generic types in function argument annotations.""" @@ -685,760 +198,6 @@ def my_transform(model: Model[int, int]): assert model.name == "transformed_model" -@pytest.mark.parametrize( - "cfg", - [ - "[a]\nb = 1\nc = 2\n\n[a.c]\nd = 3", - "[a]\nb = 1\n\n[a.c]\nd = 2\n\n[a.c.d]\ne = 3", - ], -) -def test_handle_error_duplicate_keys(cfg): - """This would cause very cryptic error when interpreting config. - (TypeError: 'X' object does not support item assignment) - """ - with pytest.raises(ConfigValidationError): - Config().from_str(cfg) - - -@pytest.mark.parametrize( - "cfg,is_valid", - [("[a]\nb = 1\n\n[a.c]\nd = 3", True), ("[a]\nb = 1\n\n[A.c]\nd = 2", False)], -) -def test_cant_expand_undefined_block(cfg, is_valid): - """Test that you can't expand a block that hasn't been created yet. This - comes up when you typo a name, and if we allow expansion of undefined blocks, - it's very hard to create good errors for those typos. 
- """ - if is_valid: - Config().from_str(cfg) - else: - with pytest.raises(ConfigValidationError): - Config().from_str(cfg) - - -def test_fill_config_overrides(): - config = { - "cfg": { - "one": 1, - "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}}, - } - } - overrides = {"cfg.two.three.evil": False} - result = my_registry.fill(config, overrides=overrides, validate=True) - assert result["cfg"]["two"]["three"]["evil"] is False - # Test that promises can be overwritten as well - overrides = {"cfg.two.three": 3} - result = my_registry.fill(config, overrides=overrides, validate=True) - assert result["cfg"]["two"]["three"] == 3 - # Test that value can be overwritten with promises and that the result is - # interpreted and filled correctly - overrides = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}} - result = my_registry.fill(config, overrides=overrides) - assert result["cfg"]["two"] is None - assert result["cfg"]["one"]["@cats"] == "catsie.v1" - assert result["cfg"]["one"]["evil"] is False - assert result["cfg"]["one"]["cute"] is True - # Overwriting with wrong types should cause validation error - with pytest.raises(ConfigValidationError): - overrides = {"cfg.two.three.evil": 20} - my_registry.fill(config, overrides=overrides, validate=True) - # Overwriting with incomplete promises should cause validation error - with pytest.raises(ConfigValidationError): - overrides = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}} - my_registry.fill(config, overrides=overrides) - # Overrides that don't match config should raise error - with pytest.raises(ConfigValidationError): - overrides = {"cfg.two.three.evil": False, "two.four": True} - my_registry.fill(config, overrides=overrides, validate=True) - with pytest.raises(ConfigValidationError): - overrides = {"cfg.five": False} - my_registry.fill(config, overrides=overrides, validate=True) - - -def test_resolve_overrides(): - config = { - "cfg": { - "one": 1, - "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}}, - } - } - overrides = {"cfg.two.three.evil": False} - result = my_registry.resolve(config, overrides=overrides, validate=True) - assert result["cfg"]["two"]["three"] == "meow" - # Test that promises can be overwritten as well - overrides = {"cfg.two.three": 3} - result = my_registry.resolve(config, overrides=overrides, validate=True) - assert result["cfg"]["two"]["three"] == 3 - # Test that value can be overwritten with promises - overrides = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}} - result = my_registry.resolve(config, overrides=overrides) - assert result["cfg"]["one"] == "meow" - assert result["cfg"]["two"] is None - # Overwriting with wrong types should cause validation error - with pytest.raises(ConfigValidationError): - overrides = {"cfg.two.three.evil": 20} - my_registry.resolve(config, overrides=overrides, validate=True) - # Overwriting with incomplete promises should cause validation error - with pytest.raises(ConfigValidationError): - overrides = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}} - my_registry.resolve(config, overrides=overrides) - # Overrides that don't match config should raise error - with pytest.raises(ConfigValidationError): - overrides = {"cfg.two.three.evil": False, "cfg.two.four": True} - my_registry.resolve(config, overrides=overrides, validate=True) - with pytest.raises(ConfigValidationError): - overrides = {"cfg.five": False} - my_registry.resolve(config, overrides=overrides, validate=True) - - 
-@pytest.mark.parametrize( - "prop,expected", - [("a.b.c", True), ("a.b", True), ("a", True), ("a.e", True), ("a.b.c.d", False)], -) -def test_is_in_config(prop, expected): - config = {"a": {"b": {"c": 5, "d": 6}, "e": [1, 2]}} - assert my_registry._is_in_config(prop, config) is expected - - -def test_resolve_prefilled_values(): - class Language(object): - def __init__(self): - ... - - @my_registry.optimizers("prefilled.v1") - def prefilled(nlp: Language, value: int = 10): - return (nlp, value) - - # Passing an instance of Language here via the config is bad, since it - # won't serialize to a string, but we still test for it - config = {"test": {"@optimizers": "prefilled.v1", "nlp": Language(), "value": 50}} - resolved = my_registry.resolve(config, validate=True) - result = resolved["test"] - assert isinstance(result[0], Language) - assert result[1] == 50 - - -def test_fill_config_dict_return_type(): - """Test that a registered function returning a dict is handled correctly.""" - - @my_registry.cats.register("catsie_with_dict.v1") - def catsie_with_dict(evil: StrictBool) -> Dict[str, bool]: - return {"not_evil": not evil} - - config = {"test": {"@cats": "catsie_with_dict.v1", "evil": False}, "foo": 10} - result = my_registry.fill({"cfg": config}, validate=True)["cfg"]["test"] - assert result["evil"] is False - assert "not_evil" not in result - result = my_registry.resolve({"cfg": config}, validate=True)["cfg"]["test"] - assert result["not_evil"] is True - - -def test_deepcopy_config(): - config = Config({"a": 1, "b": {"c": 2, "d": 3}}) - copied = config.copy() - # Same values but not same object - assert config == copied - assert config is not copied - # Check for error if value can't be pickled/deepcopied - config = Config({"a": 1, "b": numpy}) - with pytest.raises(ValueError): - config.copy() - - -def test_config_to_str_simple_promises(): - """Test that references to function registries without arguments are - serialized inline as dict.""" - config_str = """[section]\nsubsection = {"@registry":"value"}""" - config = Config().from_str(config_str) - assert config["section"]["subsection"]["@registry"] == "value" - assert config.to_str() == config_str - - -def test_config_from_str_invalid_section(): - config_str = """[a]\nb = null\n\n[a.b]\nc = 1""" - with pytest.raises(ConfigValidationError): - Config().from_str(config_str) - - config_str = """[a]\nb = null\n\n[a.b.c]\nd = 1""" - with pytest.raises(ConfigValidationError): - Config().from_str(config_str) - - -def test_config_to_str_order(): - """Test that Config.to_str orders the sections.""" - config = {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": {"g": {"h": {"i": 4, "j": 5}}}} - expected = ( - "[a]\ne = 3\n\n[a.b]\nc = 1\nd = 2\n\n[f]\n\n[f.g]\n\n[f.g.h]\ni = 4\nj = 5" - ) - config = Config(config) - assert config.to_str() == expected - - -@pytest.mark.parametrize("d", [".", ":"]) -def test_config_interpolation(d): - """Test that config values are interpolated correctly. The parametrized - value is the final divider (${a.b} vs. ${a:b}). Both should now work and be - valid. The double {{ }} in the config strings are required to prevent the - references from being interpreted as an actual f-string variable. 
- """ - c_str = """[a]\nfoo = "hello"\n\n[b]\nbar = ${foo}""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}""" - assert Config().from_str(c_str)["b"]["bar"] == "hello" - c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}!""" - assert Config().from_str(c_str)["b"]["bar"] == "hello!" - c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = "${{a{d}foo}}!\"""" - assert Config().from_str(c_str)["b"]["bar"] == "hello!" - c_str = f"""[a]\nfoo = 15\n\n[b]\nbar = ${{a{d}foo}}!""" - assert Config().from_str(c_str)["b"]["bar"] == "15!" - c_str = f"""[a]\nfoo = ["x", "y"]\n\n[b]\nbar = ${{a{d}foo}}""" - assert Config().from_str(c_str)["b"]["bar"] == ["x", "y"] - # Interpolation within the same section - c_str = f"""[a]\nfoo = "x"\nbar = ${{a{d}foo}}\nbaz = "${{a{d}foo}}y\"""" - assert Config().from_str(c_str)["a"]["bar"] == "x" - assert Config().from_str(c_str)["a"]["baz"] == "xy" - - -def test_config_interpolation_lists(): - # Test that lists are preserved correctly - c_str = """[a]\nb = 1\n\n[c]\nd = ["hello ${a.b}", "world"]""" - config = Config().from_str(c_str, interpolate=False) - assert config["c"]["d"] == ["hello ${a.b}", "world"] - config = config.interpolate() - assert config["c"]["d"] == ["hello 1", "world"] - c_str = """[a]\nb = 1\n\n[c]\nd = [${a.b}, "hello ${a.b}", "world"]""" - config = Config().from_str(c_str) - assert config["c"]["d"] == [1, "hello 1", "world"] - config = Config().from_str(c_str, interpolate=False) - # NOTE: This currently doesn't work, because we can't know how to JSON-load - # the uninterpolated list [${a.b}]. - # assert config["c"]["d"] == ["${a.b}", "hello ${a.b}", "world"] - # config = config.interpolate() - # assert config["c"]["d"] == [1, "hello 1", "world"] - c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", ${a}]""" - config = Config().from_str(c_str) - assert config["c"]["d"] == ["hello", {"b": 1}] - c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", "hello ${a}"]""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": ["hello ${a.b}"], "y": 2}]""" - config = Config().from_str(config_str) - assert config["c"]["d"] == ["hello", {"x": ["hello 1"], "y": 2}] - config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": [${a.b}], "y": 2}]""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - - -@pytest.mark.parametrize("d", [".", ":"]) -def test_config_interpolation_sections(d): - """Test that config sections are interpolated correctly. The parametrized - value is the final divider (${a.b} vs. ${a:b}). Both should now work and be - valid. The double {{ }} in the config strings are required to prevent the - references from being interpreted as an actual f-string variable. 
- """ - # Simple block references - c_str = """[a]\nfoo = "hello"\nbar = "world"\n\n[b]\nc = ${a}""" - config = Config().from_str(c_str) - assert config["b"]["c"] == config["a"] - # References with non-string values - c_str = f"""[a]\nfoo = "hello"\n\n[a.x]\ny = ${{a{d}b}}\n\n[a.b]\nc = 1\nd = [10]""" - config = Config().from_str(c_str) - assert config["a"]["x"]["y"] == config["a"]["b"] - # Multiple references in the same string - c_str = f"""[a]\nx = "string"\ny = 10\n\n[b]\nz = "${{a{d}x}}/${{a{d}y}}\"""" - config = Config().from_str(c_str) - assert config["b"]["z"] == "string/10" - # Non-string references in string (converted to string) - c_str = f"""[a]\nx = ["hello", "world"]\n\n[b]\ny = "result: ${{a{d}x}}\"""" - config = Config().from_str(c_str) - assert config["b"]["y"] == 'result: ["hello", "world"]' - # References to sections referencing sections - c_str = """[a]\nfoo = "x"\n\n[b]\nbar = ${a}\n\n[c]\nbaz = ${b}""" - config = Config().from_str(c_str) - assert config["b"]["bar"] == config["a"] - assert config["c"]["baz"] == config["b"] - # References to section values referencing other sections - c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b{d}bar}}""" - config = Config().from_str(c_str) - assert config["c"]["baz"] == config["b"]["bar"] - # References to sections with subsections - c_str = """[a]\nfoo = "x"\n\n[a.b]\nbar = 100\n\n[c]\nbaz = ${a}""" - config = Config().from_str(c_str) - assert config["c"]["baz"] == config["a"] - # Infinite recursion - c_str = """[a]\nfoo ="x"\n\n[a.b]\nbar = ${a}""" - config = Config().from_str(c_str) - assert config["a"]["b"]["bar"] == config["a"] - c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b.bar{d}foo}}""" - # We can't reference not-yet interpolated subsections - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - # Generally invalid references - c_str = f"""[a]\nfoo = ${{b{d}bar}}""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - # We can't reference sections or promises within strings - c_str = """[a]\n\n[a.b]\nfoo = "x: ${c}"\n\n[c]\nbar = 1\nbaz = 2""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - - -def test_config_from_str_overrides(): - config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = 3\n\n[f]\ng = {"x": "y"}""" - # Basic value substitution - overrides = {"a.b": 10, "a.c.d": 20} - config = Config().from_str(config_str, overrides=overrides) - assert config["a"]["b"] == 10 - assert config["a"]["c"]["d"] == 20 - assert config["a"]["c"]["e"] == 3 - # Valid values that previously weren't in config - config = Config().from_str(config_str, overrides={"a.c.f": 100}) - assert config["a"]["c"]["d"] == 2 - assert config["a"]["c"]["e"] == 3 - assert config["a"]["c"]["f"] == 100 - # Invalid keys and sections - with pytest.raises(ConfigValidationError): - Config().from_str(config_str, overrides={"f": 10}) - # This currently isn't expected to work, because the dict in f.g is not - # interpreted as a section while the config is still just the configparser - with pytest.raises(ConfigValidationError): - Config().from_str(config_str, overrides={"f.g.x": "z"}) - # With variables (values) - config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = ${a:b}""" - config = Config().from_str(config_str, overrides={"a.b": 10}) - assert config["a"]["b"] == 10 - assert config["a"]["c"]["e"] == 10 - # With variables (sections) - config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\n[e]\nf = ${a.c}""" - config = Config().from_str(config_str, overrides={"a.c.d": 20}) - assert 
config["a"]["c"]["d"] == 20 - assert config["e"]["f"] == {"d": 20} - - -def test_config_reserved_aliases(): - """Test that the auto-generated pydantic schemas auto-alias reserved - attributes like "validate" that would otherwise cause NameError.""" - - @my_registry.cats("catsie.with_alias") - def catsie_with_alias(validate: StrictBool = False): - return validate - - cfg = {"@cats": "catsie.with_alias", "validate": True} - resolved = my_registry.resolve({"test": cfg}) - filled = my_registry.fill({"test": cfg}) - assert resolved["test"] is True - assert filled["test"] == cfg - cfg = {"@cats": "catsie.with_alias", "validate": 20} - with pytest.raises(ConfigValidationError): - my_registry.resolve({"test": cfg}) - - -@pytest.mark.parametrize("d", [".", ":"]) -def test_config_no_interpolation(d): - """Test that interpolation is correctly preserved. The parametrized - value is the final divider (${a.b} vs. ${a:b}). Both should now work and be - valid. The double {{ }} in the config strings are required to prevent the - references from being interpreted as an actual f-string variable. - """ - c_str = f"""[a]\nb = 1\n\n[c]\nd = ${{a{d}b}}\ne = \"hello${{a{d}b}}"\nf = ${{a}}""" - config = Config().from_str(c_str, interpolate=False) - assert not config.is_interpolated - assert config["c"]["d"] == f"${{a{d}b}}" - assert config["c"]["e"] == f'"hello${{a{d}b}}"' - assert config["c"]["f"] == "${a}" - config2 = Config().from_str(config.to_str(), interpolate=True) - assert config2.is_interpolated - assert config2["c"]["d"] == 1 - assert config2["c"]["e"] == "hello1" - assert config2["c"]["f"] == {"b": 1} - config3 = config.interpolate() - assert config3.is_interpolated - assert config3["c"]["d"] == 1 - assert config3["c"]["e"] == "hello1" - assert config3["c"]["f"] == {"b": 1} - # Bad non-serializable value - cfg = {"x": {"y": numpy.asarray([[1, 2], [4, 5]], dtype="f"), "z": f"${{x{d}y}}"}} - with pytest.raises(ConfigValidationError): - Config(cfg).interpolate() - - -def test_config_no_interpolation_registry(): - config_str = """[a]\nbad = true\n[b]\n@cats = "catsie.v1"\nevil = ${a:bad}\n\n[c]\n d = ${b}""" - config = Config().from_str(config_str, interpolate=False) - assert not config.is_interpolated - assert config["b"]["evil"] == "${a:bad}" - assert config["c"]["d"] == "${b}" - filled = my_registry.fill(config) - resolved = my_registry.resolve(config) - assert resolved["b"] == "scratch!" - assert resolved["c"]["d"] == "scratch!" - assert filled["b"]["evil"] == "${a:bad}" - assert filled["b"]["cute"] is True - assert filled["c"]["d"] == "${b}" - interpolated = filled.interpolate() - assert interpolated.is_interpolated - assert interpolated["b"]["evil"] is True - assert interpolated["c"]["d"] == interpolated["b"] - config = Config().from_str(config_str, interpolate=True) - assert config.is_interpolated - filled = my_registry.fill(config) - resolved = my_registry.resolve(config) - assert resolved["b"] == "scratch!" - assert resolved["c"]["d"] == "scratch!" - assert filled["b"]["evil"] is True - assert filled["c"]["d"] == filled["b"] - # Resolving a non-interpolated filled config - config = Config().from_str(config_str, interpolate=False) - assert not config.is_interpolated - filled = my_registry.fill(config) - assert not filled.is_interpolated - assert filled["c"]["d"] == "${b}" - resolved = my_registry.resolve(filled) - assert resolved["c"]["d"] == "scratch!" 
- - -def test_config_deep_merge(): - config = {"a": "hello", "b": {"c": "d"}} - defaults = {"a": "world", "b": {"c": "e", "f": "g"}} - merged = Config(defaults).merge(config) - assert len(merged) == 2 - assert merged["a"] == "hello" - assert merged["b"] == {"c": "d", "f": "g"} - config = {"a": "hello", "b": {"@test": "x", "foo": 1}} - defaults = {"a": "world", "b": {"@test": "x", "foo": 100, "bar": 2}, "c": 100} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@test": "x", "foo": 1, "bar": 2} - assert merged["c"] == 100 - config = {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100} - defaults = {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@test": "x", "foo": 1} - assert merged["c"] == 100 - # Test that leaving out the factory just adds to existing - config = {"a": "hello", "b": {"foo": 1}, "c": 100} - defaults = {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@test": "y", "foo": 1, "bar": 2} - assert merged["c"] == 100 - # Test that switching to a different factory prevents the default from being added - config = {"a": "hello", "b": {"@foo": 1}, "c": 100} - defaults = {"a": "world", "b": {"@bar": "y"}} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@foo": 1} - assert merged["c"] == 100 - config = {"a": "hello", "b": {"@foo": 1}, "c": 100} - defaults = {"a": "world", "b": "y"} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@foo": 1} - assert merged["c"] == 100 - - -def test_config_deep_merge_variables(): - config_str = """[a]\nb= 1\nc = 2\n\n[d]\ne = ${a:b}""" - defaults_str = """[a]\nx = 100\n\n[d]\ny = 500""" - config = Config().from_str(config_str, interpolate=False) - defaults = Config().from_str(defaults_str) - merged = defaults.merge(config) - assert merged["a"] == {"b": 1, "c": 2, "x": 100} - assert merged["d"] == {"e": "${a:b}", "y": 500} - assert merged.interpolate()["d"] == {"e": 1, "y": 500} - # With variable in defaults: overwritten by new value - config = Config().from_str("""[a]\nb= 1\nc = 2""") - defaults = Config().from_str("""[a]\nb = 100\nc = ${a:b}""", interpolate=False) - merged = defaults.merge(config) - assert merged["a"]["c"] == 2 - - -def test_config_to_str_roundtrip(): - cfg = {"cfg": {"foo": False}} - config_str = Config(cfg).to_str() - assert config_str == "[cfg]\nfoo = false" - config = Config().from_str(config_str) - assert dict(config) == cfg - cfg = {"cfg": {"foo": "false"}} - config_str = Config(cfg).to_str() - assert config_str == '[cfg]\nfoo = "false"' - config = Config().from_str(config_str) - assert dict(config) == cfg - # Bad non-serializable value - cfg = {"cfg": {"x": numpy.asarray([[1, 2, 3, 4], [4, 5, 3, 4]], dtype="f")}} - config = Config(cfg) - with pytest.raises(ConfigValidationError): - config.to_str() - # Roundtrip with variables: preserve variables correctly (quoted/unquoted) - config_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = "${a:b}\"""" - config = Config().from_str(config_str, interpolate=False) - assert config.to_str() == config_str - - -def test_config_is_interpolated(): - """Test that a config object correctly 
reports whether it's interpolated.""" - config_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = ${a}""" - config = Config().from_str(config_str, interpolate=False) - assert not config.is_interpolated - config = config.merge(Config({"x": {"y": "z"}})) - assert not config.is_interpolated - config = Config(config) - assert not config.is_interpolated - config = config.interpolate() - assert config.is_interpolated - config = config.merge(Config().from_str(config_str, interpolate=False)) - assert not config.is_interpolated - - -@pytest.mark.parametrize( - "section_order,expected_str,expected_keys", - [ - # fmt: off - ([], "[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[h]\ni = 5\n\n[j]\nk = 6", ["a", "h", "j"]), - (["j", "h", "a"], "[j]\nk = 6\n\n[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4", ["j", "h", "a"]), - (["h"], "[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[j]\nk = 6", ["h", "a", "j"]) - # fmt: on - ], -) -def test_config_serialize_custom_sort(section_order, expected_str, expected_keys): - cfg = { - "j": {"k": 6}, - "a": {"b": 1, "d": {"e": 3}, "c": 2, "f": {"g": 4}}, - "h": {"i": 5}, - } - cfg_str = Config(cfg).to_str() - assert Config(cfg, section_order=section_order).to_str() == expected_str - keys = list(Config(section_order=section_order).from_str(cfg_str).keys()) - assert keys == expected_keys - keys = list(Config(cfg, section_order=section_order).keys()) - assert keys == expected_keys - - -def test_config_custom_sort_preserve(): - """Test that sort order is preserved when merging and copying configs, - or when configs are filled and resolved.""" - cfg = {"x": {}, "y": {}, "z": {}} - section_order = ["y", "z", "x"] - expected = "[y]\n\n[z]\n\n[x]" - config = Config(cfg, section_order=section_order) - assert config.to_str() == expected - config2 = config.copy() - assert config2.to_str() == expected - config3 = config.merge({"a": {}}) - assert config3.to_str() == f"{expected}\n\n[a]" - config4 = Config(config) - assert config4.to_str() == expected - config_str = """[a]\nb = 1\n[c]\n@cats = "catsie.v1"\nevil = true\n\n[t]\n x = 2""" - section_order = ["c", "a", "t"] - config5 = Config(section_order=section_order).from_str(config_str) - assert list(config5.keys()) == section_order - filled = my_registry.fill(config5) - assert filled.section_order == section_order - - -def test_config_pickle(): - config = Config({"foo": "bar"}, section_order=["foo", "bar", "baz"]) - data = pickle.dumps(config) - config_new = pickle.loads(data) - assert config_new == {"foo": "bar"} - assert config_new.section_order == ["foo", "bar", "baz"] - - -def test_config_fill_extra_fields(): - """Test that filling a config from a schema removes extra fields.""" - - class TestSchemaContent(BaseModel): - a: str - b: int - - class Config: - extra = "forbid" - - class TestSchema(BaseModel): - cfg: TestSchemaContent - - config = Config({"cfg": {"a": "1", "b": 2, "c": True}}) - with pytest.raises(ConfigValidationError): - my_registry.fill(config, schema=TestSchema) - filled = my_registry.fill(config, schema=TestSchema, validate=False)["cfg"] - assert filled == {"a": "1", "b": 2} - config2 = config.interpolate() - filled = my_registry.fill(config2, schema=TestSchema, validate=False)["cfg"] - assert filled == {"a": "1", "b": 2} - config3 = Config({"cfg": {"a": "1", "b": 2, "c": True}}, is_interpolated=False) - filled = my_registry.fill(config3, schema=TestSchema, validate=False)["cfg"] - assert filled == {"a": "1", "b": 2} - - class 
TestSchemaContent2(BaseModel): - a: str - b: int - - class Config: - extra = "allow" - - class TestSchema2(BaseModel): - cfg: TestSchemaContent2 - - filled = my_registry.fill(config, schema=TestSchema2, validate=False)["cfg"] - assert filled == {"a": "1", "b": 2, "c": True} - - -def test_config_validation_error_custom(): - class Schema(BaseModel): - hello: int - world: int - - config = {"hello": 1, "world": "hi!"} - with pytest.raises(ConfigValidationError) as exc_info: - my_registry._fill(config, Schema) - e1 = exc_info.value - assert e1.title == "Config validation error" - assert e1.desc is None - assert not e1.parent - assert e1.show_config is True - assert len(e1.errors) == 1 - assert e1.errors[0]["loc"] == ("world",) - assert e1.errors[0]["msg"] == "value is not a valid integer" - assert e1.errors[0]["type"] == "type_error.integer" - assert e1.error_types == set(["type_error.integer"]) - # Create a new error with overrides - title = "Custom error" - desc = "Some error description here" - e2 = ConfigValidationError.from_error(e1, title=title, desc=desc, show_config=False) - assert e2.errors == e1.errors - assert e2.error_types == e1.error_types - assert e2.title == title - assert e2.desc == desc - assert e2.show_config is False - assert e1.text != e2.text - - -def test_config_parsing_error(): - config_str = "[a]\nb c" - with pytest.raises(ConfigValidationError): - Config().from_str(config_str) - - -def test_config_fill_without_resolve(): - class BaseSchema(BaseModel): - catsie: int - - config = {"catsie": {"@cats": "catsie.v1", "evil": False}} - filled = my_registry.fill(config) - resolved = my_registry.resolve(config) - assert resolved["catsie"] == "meow" - assert filled["catsie"]["cute"] is True - with pytest.raises(ConfigValidationError): - my_registry.resolve(config, schema=BaseSchema) - filled2 = my_registry.fill(config, schema=BaseSchema) - assert filled2["catsie"]["cute"] is True - resolved = my_registry.resolve(filled2) - assert resolved["catsie"] == "meow" - # With unavailable function - class BaseSchema2(BaseModel): - catsie: Any - other: int = 12 - - config = {"catsie": {"@cats": "dog", "evil": False}} - filled3 = my_registry.fill(config, schema=BaseSchema2) - assert filled3["catsie"] == config["catsie"] - assert filled3["other"] == 12 - - -def test_config_dataclasses(): - @my_registry.cats("catsie.ragged") - def catsie_ragged(arg: Ragged): - return arg - - data = numpy.zeros((20, 4), dtype="f") - lengths = numpy.array([4, 2, 8, 1, 4], dtype="i") - ragged = Ragged(data, lengths) - config = {"cfg": {"@cats": "catsie.ragged", "arg": ragged}} - result = my_registry.resolve(config)["cfg"] - assert isinstance(result, Ragged) - assert list(result._get_starts_ends()) == [0, 4, 6, 14, 15, 19] - - -@pytest.mark.parametrize( - "greeting,value,expected", - [ - # simple substitution should go fine - [342, "${vars.a}", int], - ["342", "${vars.a}", str], - ["everyone", "${vars.a}", str], - ], -) -def test_config_interpolates(greeting, value, expected): - str_cfg = f""" - [project] - my_par = {value} - - [vars] - a = "something" - """ - overrides = {"vars.a": greeting} - cfg = Config().from_str(str_cfg, overrides=overrides) - assert type(cfg["project"]["my_par"]) == expected - - -@pytest.mark.parametrize( - "greeting,value,expected", - [ - # fmt: off - # simple substitution should go fine - ["hello 342", "${vars.a}", "hello 342"], - ["hello everyone", "${vars.a}", "hello everyone"], - ["hello tout le monde", "${vars.a}", "hello tout le monde"], - ["hello 42", "${vars.a}", "hello 42"], - # 
substituting an element in a list - ["hello 342", "[1, ${vars.a}, 3]", "hello 342"], - ["hello everyone", "[1, ${vars.a}, 3]", "hello everyone"], - ["hello tout le monde", "[1, ${vars.a}, 3]", "hello tout le monde"], - ["hello 42", "[1, ${vars.a}, 3]", "hello 42"], - # substituting part of a string - [342, "hello ${vars.a}", "hello 342"], - ["everyone", "hello ${vars.a}", "hello everyone"], - ["tout le monde", "hello ${vars.a}", "hello tout le monde"], - pytest.param("42", "hello ${vars.a}", "hello 42", marks=pytest.mark.xfail), - # substituting part of a implicit string inside a list - [342, "[1, hello ${vars.a}, 3]", "hello 342"], - ["everyone", "[1, hello ${vars.a}, 3]", "hello everyone"], - ["tout le monde", "[1, hello ${vars.a}, 3]", "hello tout le monde"], - pytest.param("42", "[1, hello ${vars.a}, 3]", "hello 42", marks=pytest.mark.xfail), - # substituting part of a explicit string inside a list - [342, "[1, 'hello ${vars.a}', '3']", "hello 342"], - ["everyone", "[1, 'hello ${vars.a}', '3']", "hello everyone"], - ["tout le monde", "[1, 'hello ${vars.a}', '3']", "hello tout le monde"], - pytest.param("42", "[1, 'hello ${vars.a}', '3']", "hello 42", marks=pytest.mark.xfail), - # more complicated example - [342, "[{'name':'x','script':['hello ${vars.a}']}]", "hello 342"], - ["everyone", "[{'name':'x','script':['hello ${vars.a}']}]", "hello everyone"], - ["tout le monde", "[{'name':'x','script':['hello ${vars.a}']}]", "hello tout le monde"], - pytest.param("42", "[{'name':'x','script':['hello ${vars.a}']}]", "hello 42", marks=pytest.mark.xfail), - # fmt: on - ], -) -def test_config_overrides(greeting, value, expected): - str_cfg = f""" - [project] - commands = {value} - - [vars] - a = "world" - """ - overrides = {"vars.a": greeting} - assert "${vars.a}" in str_cfg - cfg = Config().from_str(str_cfg, overrides=overrides) - assert expected in str(cfg) - - def test_arg_order_is_preserved(): str_cfg = """ [model] diff --git a/thinc/tests/test_loss.py b/thinc/tests/test_loss.py index 710a88d61..75206d240 100644 --- a/thinc/tests/test_loss.py +++ b/thinc/tests/test_loss.py @@ -168,7 +168,9 @@ def test_sequence_categorical_crossentropy(guesses, labels, names): assert d_scores1[1][0] == pytest.approx(0.4, eps) assert d_scores1[1][1] == pytest.approx(-0.4, eps) # The normalization divides the difference (e.g. 
0.4) by the number of seqs - d_scores = SequenceCategoricalCrossentropy(normalize=True, names=names).get_grad(guesses, labels) + d_scores = SequenceCategoricalCrossentropy(normalize=True, names=names).get_grad( + guesses, labels + ) d_scores1 = d_scores[0] d_scores2 = d_scores[1] @@ -189,7 +191,9 @@ def test_sequence_categorical_crossentropy(guesses, labels, names): assert d_scores2[0][0] == pytest.approx(0.1, eps) assert d_scores2[0][1] == pytest.approx(-0.35, eps) - loss = SequenceCategoricalCrossentropy(normalize=True, names=names).get_loss(guesses, labels) + loss = SequenceCategoricalCrossentropy(normalize=True, names=names).get_loss( + guesses, labels + ) assert loss == pytest.approx(1.09, eps) @@ -200,9 +204,9 @@ def test_sequence_categorical_crossentropy(guesses, labels, names): ], ) def test_sequence_categorical_missing_negative(guesses, labels, names): - d_scores = SequenceCategoricalCrossentropy(normalize=False, names=names, neg_prefix="!", missing_value="").get_grad( - guesses, labels - ) + d_scores = SequenceCategoricalCrossentropy( + normalize=False, names=names, neg_prefix="!", missing_value="" + ).get_grad(guesses, labels) d_scores0 = d_scores[0] # [0.1, 0.5, 0.6] should be A @@ -292,8 +296,16 @@ def test_cosine_unmatched(): ("SequenceCategoricalCrossentropy.v1", {}, ([scores0], [labels0])), ("CategoricalCrossentropy.v2", {"neg_prefix": "!"}, (scores0, labels0)), ("CategoricalCrossentropy.v3", {"neg_prefix": "!"}, (scores0, labels0)), - ("SequenceCategoricalCrossentropy.v2", {"neg_prefix": "!"}, ([scores0], [labels0])), - ("SequenceCategoricalCrossentropy.v3", {"neg_prefix": "!"}, ([scores0], [labels0])), + ( + "SequenceCategoricalCrossentropy.v2", + {"neg_prefix": "!"}, + ([scores0], [labels0]), + ), + ( + "SequenceCategoricalCrossentropy.v3", + {"neg_prefix": "!"}, + ([scores0], [labels0]), + ), ("L2Distance.v1", {}, (scores0, scores0)), ( "CosineDistance.v1", diff --git a/thinc/tests/test_serialize.py b/thinc/tests/test_serialize.py index f3a937c34..b89fc2d94 100644 --- a/thinc/tests/test_serialize.py +++ b/thinc/tests/test_serialize.py @@ -55,7 +55,7 @@ def test_simple_model_roundtrip_bytes(): def test_simple_model_roundtrip_bytes_length(): - """ Ensure that serialization of non-initialized weight matrices goes fine """ + """Ensure that serialization of non-initialized weight matrices goes fine""" model1 = Maxout(5, 10, nP=2) model2 = Maxout(5, 10, nP=2) @@ -186,7 +186,7 @@ def test_simple_model_can_from_dict(): assert model.can_from_dict(model_dict) # Test check without initialize assert Maxout(5, 10, nP=2).can_from_dict(model_dict) - # Test not-strict check + # Test not-strict check assert not Maxout(10, 5, nP=2).can_from_dict(model_dict) assert Maxout(5, nP=2).can_from_dict(model_dict) diff --git a/thinc/tests/test_util.py b/thinc/tests/test_util.py index b51f23a19..8d2d0058d 100644 --- a/thinc/tests/test_util.py +++ b/thinc/tests/test_util.py @@ -3,11 +3,20 @@ from hypothesis import given from thinc.api import get_width, Ragged, Padded from thinc.util import get_array_module, is_numpy_array, to_categorical +from thinc.util import is_cupy_array from thinc.util import convert_recursive from thinc.types import ArgsKwargs from . 
import strategies +ALL_XP = [numpy] +try: + import cupy + + ALL_XP.append(cupy) +except ImportError: + pass + @pytest.mark.parametrize( "obj,width", @@ -39,14 +48,30 @@ def test_get_width_fail(obj): get_width(obj) -def test_array_module_cpu_gpu_helpers(): - xp = get_array_module(0) - assert hasattr(xp, "ndarray") - assert is_numpy_array(numpy.zeros((1, 2))) - assert not is_numpy_array((1, 2)) +@pytest.mark.parametrize("xp", ALL_XP) +def test_array_module_cpu_gpu_helpers(xp): + error = ( + "Only numpy and cupy arrays are supported" + ", but found instead. If " + "get_array_module module wasn't called " + "directly, this might indicate a bug in Thinc." + ) + with pytest.raises(ValueError, match=error): + get_array_module(0) + zeros = xp.zeros((1, 2)) + xp_ = get_array_module(zeros) + assert xp_ == xp + if xp == numpy: + assert is_numpy_array(zeros) + assert not is_numpy_array((1, 2)) + else: + assert is_cupy_array(zeros) + assert not is_cupy_array((1, 2)) -@given(label_smoothing=strategies.floats(min_value=0.0, max_value=0.5)) +@given( + label_smoothing=strategies.floats(min_value=0.0, max_value=0.5, exclude_max=True) +) def test_to_categorical(label_smoothing): # Test without n_classes one_hot = to_categorical(numpy.asarray([1, 2], dtype="i")) @@ -113,6 +138,12 @@ def test_to_categorical(label_smoothing): ): to_categorical(numpy.asarray([0, 0, 0]), label_smoothing=0.01), + with pytest.raises(ValueError, match=r"label_smoothing parameter"): + to_categorical(numpy.asarray([0, 1, 2, 3, 4]), label_smoothing=0.8) + + with pytest.raises(ValueError, match=r"label_smoothing parameter"): + to_categorical(numpy.asarray([0, 1, 2, 3, 4]), label_smoothing=0.88) + def test_convert_recursive(): is_match = lambda obj: obj == "foo" diff --git a/thinc/types.py b/thinc/types.py index 74498d159..629a79e82 100644 --- a/thinc/types.py +++ b/thinc/types.py @@ -4,12 +4,11 @@ from dataclasses import dataclass import numpy import sys +from .compat import has_cupy, cupy -try: - import cupy - +if has_cupy: get_array_module = cupy.get_array_module -except (ImportError, AttributeError): +else: get_array_module = lambda obj: numpy # Use typing_extensions for Python versions < 3.8 @@ -47,6 +46,7 @@ ArrayT = TypeVar("ArrayT") SelfT = TypeVar("SelfT") Array1dT = TypeVar("Array1dT", bound="Array1d") +FloatsXdT = TypeVar("FloatsXdT", "Floats1d", "Floats2d", "Floats3d", "Floats4d") # These all behave the same as far as indexing is concerned Slicish = Union[slice, List[int], "ArrayXd"] @@ -162,7 +162,7 @@ def __bytes__(self) -> bytes: ... def __str__(self) -> str: ... def __repr__(self) -> str: ... def __copy__(self, order: str = ...): ... - def __deepcopy__(self, memo: dict) -> ArrayT: ... + def __deepcopy__(self: SelfT, memo: dict) -> SelfT: ... def __lt__(self, other): ... def __le__(self, other): ... def __eq__(self, other): ... @@ -224,7 +224,7 @@ def clip(self, a_min: Any, a_max: Any, out: Optional[ArrayT]) -> ArrayT: ... def max(self, axis: int = -1, out: Optional[ArrayT] = None) -> ArrayT: ... # def mean(self, axis: int = -1, dtype: Optional[DTypes] = None, out: Optional[SelfT] = None, keepdims: bool = False) -> "Array": ... def min(self, axis: int = -1, out: Optional[ArrayT] = None) -> ArrayT: ... - def nonzero(self) -> ArrayT: ... + def nonzero(self: SelfT) -> SelfT: ... def prod(self, axis: int = -1, dtype: Optional[DTypes] = None, out: Optional[ArrayT] = None, keepdims: bool = False) -> ArrayT: ... def round(self, decimals: int = 0, out: Optional[ArrayT] = None) -> ArrayT: ... 
# def sum(self, axis: int = -1, dtype: Optional[DTypes] = None, out: Optional[ArrayT] = None, keepdims: bool = False) -> ArrayT: ... @@ -317,7 +317,7 @@ class Floats1d(_Array1d, _Floats): @classmethod def __get_validators__(cls): - """Runtine validation for pydantic.""" + """Runtime validation for pydantic.""" yield lambda v: validate_array(v, ndim=1, dtype="f") def __iter__(self) -> Iterator[float]: ... diff --git a/thinc/util.py b/thinc/util.py index dacfb09c2..43fb115d2 100644 --- a/thinc/util.py +++ b/thinc/util.py @@ -13,54 +13,13 @@ import contextlib from contextvars import ContextVar from dataclasses import dataclass +from .compat import has_cupy, has_mxnet, has_torch, has_tensorflow +from .compat import has_cupy_gpu, has_torch_cuda_gpu, has_gpu +from .compat import has_torch_mps_gpu +from .compat import torch, cupy, tensorflow as tf, mxnet as mx, cupy_from_dlpack DATA_VALIDATION: ContextVar[bool] = ContextVar("DATA_VALIDATION", default=False) -try: # pragma: no cover - import cupy - - has_cupy = True -except (ImportError, AttributeError): - cupy = None - has_cupy = False - - -try: # pragma: no cover - import torch - from torch import tensor - import torch.utils.dlpack - - has_torch = True - has_torch_gpu = torch.cuda.device_count() != 0 - torch_version = Version(str(torch.__version__)) - has_torch_amp = ( - torch_version >= Version("1.9.0") - and not torch.cuda.amp.common.amp_definitely_not_available() - ) -except ImportError: # pragma: no cover - has_torch = False - has_torch_gpu = False - has_torch_amp = False - torch_version = Version("0.0.0") - -try: # pragma: no cover - import tensorflow.experimental.dlpack - import tensorflow as tf - - has_tensorflow = True - has_tensorflow_gpu = len(tf.config.get_visible_devices("GPU")) > 0 -except ImportError: # pragma: no cover - has_tensorflow = False - has_tensorflow_gpu = False - - -try: # pragma: no cover - import mxnet as mx - - has_mxnet = True -except ImportError: # pragma: no cover - has_mxnet = False - from .types import ArrayXd, ArgsKwargs, Ragged, Padded, FloatsXd, IntsXd # noqa: E402 from . import types # noqa: E402 from typing import TYPE_CHECKING @@ -69,22 +28,40 @@ from .api import Ops +def get_torch_default_device() -> "torch.device": + if torch is None: + raise ValueError("Cannot get default Torch device when Torch is not available.") + + from .backends import get_current_ops + from .backends.cupy_ops import CupyOps + from .backends.mps_ops import MPSOps + + ops = get_current_ops() + if isinstance(ops, CupyOps): + device_id = torch.cuda.current_device() + return torch.device(f"cuda:{device_id}") + elif isinstance(ops, MPSOps): + return torch.device("mps") + + return torch.device("cpu") + + def get_array_module(arr): # pragma: no cover - if is_cupy_array(arr): + if is_numpy_array(arr): + return numpy + elif is_cupy_array(arr): return cupy else: - return numpy + raise ValueError( + "Only numpy and cupy arrays are supported" + f", but found {type(arr)} instead. If " + "get_array_module module wasn't called " + "directly, this might indicate a bug in Thinc." 
+ ) def gpu_is_available(): - if not has_cupy: - return False - - try: - cupy.cuda.runtime.getDeviceCount() - return True - except cupy.cuda.runtime.CUDARuntimeError: - return False + return has_gpu def fix_random_seed(seed: int = 0) -> None: # pragma: no cover @@ -93,9 +70,9 @@ def fix_random_seed(seed: int = 0) -> None: # pragma: no cover numpy.random.seed(seed) if has_torch: torch.manual_seed(seed) - if has_cupy and gpu_is_available(): + if has_cupy_gpu: cupy.random.seed(seed) - if has_torch and torch.cuda.is_available(): + if has_torch and has_torch_cuda_gpu: torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False @@ -133,10 +110,18 @@ def is_torch_array(obj: Any) -> bool: # pragma: no cover return False -def is_torch_gpu_array(obj: Any) -> bool: # pragma: no cover +def is_torch_cuda_array(obj: Any) -> bool: # pragma: no cover return is_torch_array(obj) and obj.is_cuda +def is_torch_gpu_array(obj: Any) -> bool: # pragma: no cover + return is_torch_cuda_array(obj) or is_torch_mps_array(obj) + + +def is_torch_mps_array(obj: Any) -> bool: # pragma: no cover + return is_torch_array(obj) and hasattr(obj, "is_mps") and obj.is_mps + + def is_tensorflow_array(obj: Any) -> bool: # pragma: no cover if not has_tensorflow: return False @@ -174,17 +159,15 @@ def to_numpy(data): # pragma: no cover def set_active_gpu(gpu_id: int) -> "cupy.cuda.Device": # pragma: no cover """Set the current GPU device for cupy and torch (if available).""" - import cupy.cuda.device + if not has_cupy_gpu: + raise ValueError("No CUDA GPU devices detected") device = cupy.cuda.device.Device(gpu_id) device.use() - try: - import torch + if has_torch_cuda_gpu: torch.cuda.set_device(gpu_id) - torch.set_default_tensor_type("torch.cuda.FloatTensor") - except ImportError: - pass + return device @@ -194,30 +177,29 @@ def require_cpu() -> bool: # pragma: no cover ops = get_ops("cpu") set_current_ops(ops) - set_torch_tensor_type_for_ops(ops) return True def prefer_gpu(gpu_id: int = 0) -> bool: # pragma: no cover """Use GPU if it's available. Returns True if so, False otherwise.""" - from .backends.cupy_ops import CupyOps - - if CupyOps.xp is None: - return False - else: + if has_gpu: require_gpu(gpu_id=gpu_id) - return True + return has_gpu def require_gpu(gpu_id: int = 0) -> bool: # pragma: no cover - from .backends import set_current_ops, CupyOps + from .backends import set_current_ops, CupyOps, MPSOps - if CupyOps.xp is None: - raise ValueError("GPU is not accessible. Was the library installed correctly?") + if not has_gpu: + raise ValueError("No GPU devices detected") + + if has_cupy_gpu: + set_current_ops(CupyOps()) + set_active_gpu(gpu_id) + else: + set_current_ops(MPSOps()) - set_current_ops(CupyOps()) - set_active_gpu(gpu_id) return True @@ -237,16 +219,15 @@ def to_categorical( *, label_smoothing: float = 0.0, ) -> FloatsXd: - if not 0.0 <= label_smoothing < 0.5: - raise ValueError( - "label_smoothing should be greater or " - "equal to 0.0 and less than 0.5, " - f"but {label_smoothing} was provided." 
- ) if n_classes is None: n_classes = int(numpy.max(Y) + 1) # type: ignore + if label_smoothing < 0.0: + raise ValueError( + "Label-smoothing parameter has to be greater than or equal to 0" + ) + if label_smoothing == 0.0: if n_classes == 0: raise ValueError("n_classes should be at least 1") @@ -259,6 +240,14 @@ def to_categorical( ) nongold_prob = label_smoothing / (n_classes - 1) + max_smooth = (n_classes - 1) / n_classes + if n_classes > 1 and label_smoothing >= max_smooth: + raise ValueError( + f"For {n_classes} classes " + "label_smoothing parameter has to be less than " + f"{max_smooth}, but found {label_smoothing}." + ) + xp = get_array_module(Y) label_distr = xp.full((n_classes, n_classes), nongold_prob, dtype="float32") xp.fill_diagonal(label_distr, 1 - label_smoothing) @@ -359,17 +348,29 @@ def iterate_recursive(is_match: Callable[[Any], bool], obj: Any) -> Any: def xp2torch( - xp_tensor: ArrayXd, requires_grad: bool = False + xp_tensor: ArrayXd, + requires_grad: bool = False, + device: Optional["torch.device"] = None, ) -> "torch.Tensor": # pragma: no cover """Convert a numpy or cupy tensor to a PyTorch tensor.""" assert_pytorch_installed() + + if device is None: + device = get_torch_default_device() + if hasattr(xp_tensor, "toDlpack"): dlpack_tensor = xp_tensor.toDlpack() # type: ignore torch_tensor = torch.utils.dlpack.from_dlpack(dlpack_tensor) + elif hasattr(xp_tensor, "__dlpack__"): + torch_tensor = torch.utils.dlpack.from_dlpack(xp_tensor) else: torch_tensor = torch.from_numpy(xp_tensor) + + torch_tensor = torch_tensor.to(device) + if requires_grad: torch_tensor.requires_grad_() + return torch_tensor @@ -382,14 +383,14 @@ def torch2xp( from .api import NumpyOps assert_pytorch_installed() - if is_torch_gpu_array(torch_tensor): + if is_torch_cuda_array(torch_tensor): if isinstance(ops, NumpyOps): return torch_tensor.detach().cpu().numpy() else: - return cupy.fromDlpack(torch.utils.dlpack.to_dlpack(torch_tensor)) + return cupy_from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor)) else: if isinstance(ops, NumpyOps) or ops is None: - return torch_tensor.detach().numpy() + return torch_tensor.detach().cpu().numpy() else: return cupy.asarray(torch_tensor) @@ -401,7 +402,10 @@ def xp2tensorflow( assert_tensorflow_installed() if hasattr(xp_tensor, "toDlpack"): dlpack_tensor = xp_tensor.toDlpack() # type: ignore - tf_tensor = tensorflow.experimental.dlpack.from_dlpack(dlpack_tensor) + tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_tensor) + elif hasattr(xp_tensor, "__dlpack__"): + dlpack_tensor = xp_tensor.__dlpack__() # type: ignore + tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_tensor) else: tf_tensor = tf.convert_to_tensor(xp_tensor) if as_variable: @@ -430,8 +434,8 @@ def tensorflow2xp( if isinstance(ops, NumpyOps): return tf_tensor.numpy() else: - dlpack_tensor = tensorflow.experimental.dlpack.to_dlpack(tf_tensor) - return cupy.fromDlpack(dlpack_tensor) + dlpack_tensor = tf.experimental.dlpack.to_dlpack(tf_tensor) + return cupy_from_dlpack(dlpack_tensor) else: if isinstance(ops, NumpyOps) or ops is None: return tf_tensor.numpy() @@ -465,7 +469,7 @@ def mxnet2xp( if isinstance(ops, NumpyOps): return mx_tensor.detach().asnumpy() else: - return cupy.fromDlpack(mx_tensor.to_dlpack_for_write()) + return cupy_from_dlpack(mx_tensor.to_dlpack_for_write()) else: if isinstance(ops, NumpyOps) or ops is None: return mx_tensor.detach().asnumpy() @@ -579,22 +583,6 @@ def use_nvtx_range(message: str, id_color: int = -1): yield -def set_torch_tensor_type_for_ops(ops): - 
"""Set the PyTorch default tensor type for the given ops. This is a - no-op if PyTorch is not available.""" - from .backends.cupy_ops import CupyOps - - try: - import torch - - if CupyOps.xp is not None and isinstance(ops, CupyOps): - torch.set_default_tensor_type("torch.cuda.FloatTensor") - else: - torch.set_default_tensor_type("torch.FloatTensor") - except ImportError: - pass - - @dataclass class ArrayInfo: """Container for info for checking array compatibility.""" @@ -619,6 +607,7 @@ def check_consistency(self, arr: ArrayXd): __all__ = [ "get_array_module", + "get_torch_default_device", "fix_random_seed", "is_cupy_array", "is_numpy_array", @@ -636,6 +625,5 @@ def check_consistency(self, arr: ArrayXd): "DataValidationError", "make_tempfile", "use_nvtx_range", - "set_torch_tensor_type_for_ops", "ArrayInfo", ] diff --git a/website/Dockerfile b/website/Dockerfile new file mode 100644 index 000000000..b1965b17a --- /dev/null +++ b/website/Dockerfile @@ -0,0 +1,16 @@ +FROM node:11.15.0 + +WORKDIR /thinc-ai + +RUN npm install -g gatsby-cli@2.7.4 + +COPY package.json . +COPY package-lock.json . + +RUN npm install + +# This is so the installed node_modules will be up one directory +# from where a user mounts files, so that they don't accidentally mount +# their own node_modules from a different build +# https://nodejs.org/api/modules.html#modules_loading_from_node_modules_folders +WORKDIR /thinc-ai/website/ diff --git a/website/README.md b/website/README.md index 8c8d53a98..f1c4ec5ba 100644 --- a/website/README.md +++ b/website/README.md @@ -14,6 +14,30 @@ npm run dev # start dev server A `.prettierrc` is included in the repo, so if you set up auto-formatting with Prettier, it should match the style. +## Build and run the website in a Docker container + +Rather than installing NPM locally, you can also build a Docker container with +the prerequisite dependencies: + +```bash +docker build -t thinc-ai . +``` + +Afterwards, the website can be built and run in the container: + +```bash +docker run --rm -it \ + -v $PWD:/thinc-ai/website \ + -p 8000:8000 \ + thinc-ai \ + gatsby develop -H 0.0.0.0 +``` + +This is currently the only way to build the website on ARM64 Macs, since the +required Node.js version is not built for macOS/ARM64. + +These commands also work with Podman by replacing `docker` by `podman`. + ## Directory structure - `/docs`: Docs pages as Markdown. diff --git a/website/docs/_quickstart.json b/website/docs/_quickstart.json index b3f258070..bed6629a8 100644 --- a/website/docs/_quickstart.json +++ b/website/docs/_quickstart.json @@ -12,12 +12,7 @@ { "label": "9.2", "value": "cuda92" }, { "label": "10.0", "value": "cuda100" }, { "label": "10.1", "value": "cuda101" }, - { "label": "11.0", "value": "cuda110" }, - { "label": "11.1", "value": "cuda111" }, - { "label": "11.2", "value": "cuda112" }, - { "label": "11.3", "value": "cuda113" }, - { "label": "11.4", "value": "cuda114" }, - { "label": "11.5", "value": "cuda115" } + { "label": "10.2, 11.0+", "value": "cuda-autodetect" } ] }, { diff --git a/website/docs/api-backends.md b/website/docs/api-backends.md index 7f5c1e8d7..c5a54cff8 100644 --- a/website/docs/api-backends.md +++ b/website/docs/api-backends.md @@ -382,6 +382,26 @@ the inputs and outputs. | `zeros` | bool | Fill the array with zeros (default: `True`). | | **RETURNS** | ArrayXd | An array of the correct shape and data type. 
| +### Ops.cblas {#cblas tag="method"} + + + +- **default:** +- **numpy:** +- **cupy:** + + + +Get a table of C BLAS functions usable in Cython `cdef nogil` functions. This +method does not take any arguments. + + + +This method is only supported by `NumpyOps`. A `NotImplementedError` exception +is raised when calling this method on `Ops` or `CupyOps`. + + + ### Ops.to_numpy {#to_numpy tag="method"} @@ -907,6 +927,47 @@ Backpropagate the Swish activation | `inplace` | bool | If `True`, the `dY` array is modified in place. | | **RETURNS** | FloatsXd | The gradient of the input. | +### Ops.dish {#dish tag="method" new="8.1.1"} + + + +- **default:** +- **numpy:** +- **cupy:** + + + +Dish or "Daniël's Swish-like activation" is an activation function with a non-monotinic shape similar to +[GELU](#gelu), [Swish](#swish) and [Mish](#mish). However, Dish does not rely on +elementary functions like `exp` or `erf`, making it much +[faster to compute](https://twitter.com/danieldekok/status/1484898130441166853) +in most cases. + +| Argument | Type | Description | +| ----------- | ----------------- | ------------------------------------------ | +| `X` | FloatsXd | The inputs. | +| `inplace` | bool | If `True`, the array is modified in place. | +| **RETURNS** | FloatsXd | The outputs. | + +### Ops.backprop_dish {#backprop_dish tag="method" new="8.1.1"} + + + +- **default:** +- **numpy:** +- **cupy:** + + + +Backpropagate the Dish activation. + +| Argument | Type | Description | +| ----------- | ----------------- | ----------------------------------------------- | +| `dY` | FloatsXd | Gradients of the output array. | +| `X` | FloatsXd | The inputs to the forward pass. | +| `inplace` | bool | If `True`, the `dY` array is modified in place. | +| **RETURNS** | FloatsXd | The gradient of the input. | + ### Ops.gelu {#gelu tag="method"} @@ -1193,6 +1254,82 @@ Backpropagate the hard Swish MobileNet activation. | `inplace` | bool | If `True`, the `dY` array is modified in place. | | **RETURNS** | FloatsXd | The gradient of the input. | +### Ops.reduce_first {#reduce_first tag="method"} + + + +- **default:** +- **numpy:** default +- **cupy:** default + + + +Perform sequence-wise first pooling for data in the ragged format. Zero-length +sequences are not allowed. A `ValueError` is raised if any element in `lengths` +is zero. + +| Argument | Type | Description | +| ----------- | ------------------------------- | --------------------------------------------------------------------- | +| `X` | Floats2d | The concatenated sequences. | +| `lengths` | Ints1d | The sequence lengths. | +| **RETURNS** | Tuple[Floats2d,Ints1d] | The first vector of each sequence and the sequence start/end indices. | + +### Ops.backprop_reduce_first {#backprop_reduce_first tag="method"} + + + +- **default:** +- **numpy:** default +- **cupy:** default + + + +Backpropagate the `reduce_first` operation. + +| Argument | Type | Description | +| ------------- | ----------------- | ------------------------------------------- | +| `d_firsts` | Floats2d | The gradient of the outputs. | +| `starts_ends` | Ints1d | The sequence start/end indices. | +| **RETURNS** | Floats2d | The gradient of the concatenated sequences. | + +### Ops.reduce_last {#reduce_last tag="method"} + + + +- **default:** +- **numpy:** default +- **cupy:** default + + + +Perform sequence-wise last pooling for data in the ragged format. Zero-length +sequences are not allowed. A `ValueError` is raised if any element in `lengths` +is zero. 
+ +| Argument | Type | Description | +| ----------- | ------------------------------- | ------------------------------------------------------------------------------- | +| `X` | Floats2d | The concatenated sequences. | +| `lengths` | Ints1d | The sequence lengths. | +| **RETURNS** | Tuple[Floats2d,Ints1d] | The last vector of each sequence and the indices of the last sequence elements. | + +### Ops.backprop_reduce_last {#backprop_reduce_last tag="method"} + + + +- **default:** +- **numpy:** default +- **cupy:** default + + + +Backpropagate the `reduce_last` operation. + +| Argument | Type | Description | +| ----------- | ----------------- | ------------------------------------------- | +| `d_lasts` | Floats2d | The gradient of the outputs. | +| `lasts` | Ints1d | Indices of the last sequence elements. | +| **RETURNS** | Floats2d | The gradient of the concatenated sequences. | + ### Ops.reduce_sum {#reduce_sum tag="method"} @@ -1278,8 +1415,8 @@ Backpropagate the `reduce_mean` operation. Perform sequence-wise max pooling for data in the ragged format. Zero-length -sequences are not allowed. A `ValueError` is raised if any element in -`lengths` is zero. +sequences are not allowed. A `ValueError` is raised if any element in `lengths` +is zero. | Argument | Type | Description | | ----------- | -------------------------------- | --------------------------- | @@ -1344,6 +1481,25 @@ Create hashed ngram features. | `keys` | Ints1d | The input sequence. | | **RETURNS** | Ints1d | The hashed ngrams. | +### Ops.gather_add {#gather_add tag="method" new="8.1"} + + + +- **default:** +- **numpy:** +- **cupy:** + + + +Gather rows from `table` with shape `(T, O)` using array `indices` with shape +`(B, K)`, then sum the resulting array with shape `(B, K, O)` over the `K` axis. + +| Argument | Type | Description | +| ----------- | ----------------- | ----------------------- | +| `table` | Floats2d | The array to increment. | +| `indices` | Ints2d | The indices to use. | +| **RETURNS** | Floats2d | The summed rows. | + ### Ops.scatter_add {#scatter_add tag="method"} diff --git a/website/docs/api-layers.md b/website/docs/api-layers.md index 1c43a9d7a..e204c4c46 100644 --- a/website/docs/api-layers.md +++ b/website/docs/api-layers.md @@ -44,6 +44,39 @@ Primarily used within [`siamese`](#siamese) neural networks. https://github.com/explosion/thinc/blob/master/thinc/layers/cauchysimilarity.py ``` +### Dish {#dish tag="function"} + + + +- **Input:** Floats2d +- **Output:** Floats2d +- **Parameters:** W, + b + + + +A dense layer with the Dish activation function. Dish or "Daniël's Swish-like +activation" is an activation function with a non-monotinic shape similar to +[GELU](#gelu), [Swish](#swish) and [Mish](#mish). However, Dish does not rely on +elementary functions like `exp` or `erf`, making it much +[faster to compute](https://twitter.com/danieldekok/status/1484898130441166853) +in most cases. + +| Argument | Type | Description | +| -------------- | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------ | +| `nO` | Optional[int] | The size of the output vectors. | +| `nI` | Optional[int] | The size of the input vectors. | +| _keyword-only_ | | | +| `init_W` | Callable | A function to initialize the weights matrix. Defaults to [`he_normal_init`](/docs/api-initializers#he_normal_init) | +| `init_b` | Callable | A function to initialize the bias vector. 
Defaults to [`zero_init`](/docs/api-initializers#zero_init). | +| `dropout` | Optional[float] | Dropout rate to avoid overfitting. | +| `normalize` | bool | Whether or not to apply [layer normalization](#layernorm). Defaults to `False`. | +| **RETURNS** | Model[Floats2d, Floats2d] | The created dense layer. | + +```python +https://github.com/explosion/thinc/blob/master/thinc/layers/dish.py +``` + ### Dropout {#dropout tag="function"} @@ -835,8 +868,8 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/reduce_last.py Pooling layer that reduces the dimensions of the data by selecting the maximum -value for each feature. A `ValueError` is raised if any element in `lengths` -is zero. +value for each feature. A `ValueError` is raised if any element in `lengths` is +zero. | Argument | Type | Description | | ----------- | -------------------------------- | -------------------------- | @@ -1234,22 +1267,27 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/padded2list.py -- **Input:** Sequence[Any] +- **Input:** Union[Sequence[Hashable], Ints1d, Ints2d] - **Output:** Ints2d -Remap string or integer inputs using a mapping table, usually as a preprocess -before embeddings. The mapping table can be passed in on input, or updated after -the layer has been created. The mapping table is stored in the `"mapping_table"` -attribute. +Remap a sequence of strings, integers or other hashable inputs using a mapping +table, usually as a preprocessing step before embeddings. The input can also be +a two dimensional integer array in which case the `column` attribute tells the +`remap_ids` layer which column of the array to map with the `mapping_table`. +Both attributes can be passed on initialization, but since the layer is designed +to retrieve them from `model.attrs` during `forward`, they can be set any time +before calling `forward`. This means that they can also be changed between +calls. Before calling `forward` the `mapping_table` has to be set and for 2D +inputs the `column` is also required. -| Argument | Type | Description | -| --------------- | ------------------------------------- | ------------------------------------------------------------------------------------------------------------ | -| `mapping_table` | Dict[Any, int] | The mapping table to use. Can also be set after initialization by writing to `model.attrs["mapping_table"]`. | -| `default` | int | The default value if the input does not have an entry in the mapping table. | -| `dtype` | DTypes | The data type of the array. | -| **RETURNS** | Model[Sequence[Any], Ints2d] | The layer to compute the transformation. | +| Argument | Type | Description | +| --------------- | ----------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------ | +| `mapping_table` | Dict[Any, int] | The mapping table to use. Can also be set after initialization by writing to `model.attrs["mapping_table"]`. | +| `default` | int | The default value if the input does not have an entry in the mapping table. | +| `column` | int | The column to apply the mapper to in case of 2D input. | +| **RETURNS** | Model[Union[Sequence[Hashable], Ints1d, Ints2d], Ints2d] | The layer to compute the transformation. 
| ```python https://github.com/explosion/thinc/blob/master/thinc/layers/remap_ids.py @@ -1531,6 +1569,45 @@ model.initialize() https://github.com/explosion/thinc/blob/master/thinc/layers/with_nvtx_range.py ``` +### with_signpost_interval {#with_signpost_interval tag="function" new="8.1.1"} + + + +- **Input:** Any +- **Output:** Any + + + +Layer that wraps any layer and marks the init, forward and backprop passes as a +(macOS) signpost interval. This can be helpful when profiling the performance of +a layer using macOS +[Instruments.app](https://help.apple.com/instruments/mac/current/). Use of this +layer requires that the +[`os-signpost`](https://github.com/explosion/os-signpost) package is installed. + +```python +### Example +from os_signpost import Signposter +from thinc.api import Linear, with_signpost_interval + +signposter = Signposter("com.example.my_subsystem", + Signposter.Category.DynamicTracing) + +model = with_signpost_interval(Linear(2, 5), signposter) +model.initialize() +``` + +| Argument | Type | Description | +| ------------ | --------------------------------- | ------------------------------------------------------------------------------- | +| `layer` | Model | The layer to wrap. | +| `signposter` | os_signposter.Signposter | `Signposter` object to log the interval with. | +| `name` | Optional[str] | Optional name for the wrapped layer. Defaults to the name of the wrapped layer. | +| **RETURNS** | Model | The wrapped layer. | + +```python +https://github.com/explosion/thinc/blob/master/thinc/layers/with_signpost_interval.py +``` + --- ## Wrappers {#wrappers} diff --git a/website/docs/api-loss.md b/website/docs/api-loss.md index 3720a78f5..5ccf84b58 100644 --- a/website/docs/api-loss.md +++ b/website/docs/api-loss.md @@ -70,9 +70,8 @@ are expected to be in the range of `[0, 1]`. They can both represent exclusive classes from multi-class cross-entropy (generally coming from a `softmax` layer) or could be classwise binary decisions for multi-label cross-entropy (`sigmoid` layer). The `truths` are most commonly provided as labels in `Ints1d`, -`List[int]` or `List[str]` format. -Alternatively, users can provide `truths` as a `Floats2d` for -example to encode label-confidences. +`List[int]` or `List[str]` format. Alternatively, users can provide `truths` as +a `Floats2d` for example to encode label-confidences. @@ -98,7 +97,7 @@ normalize = true | `normalize` | bool | Normalize and divide by number of examples given. | | `neg_prefix` | str | Prefix used to indicate that a label is negative e.g. "!sci-fi". | | `missing_value` | Union[str, int] | Specific label that indicates the value is missing and should not be considered for training/evaluation purposes, e.g. empty string `""` or `0`. | -| `label_smoothing` | float | Smoothing-coefficient for label-smoothing in range of [0, 0.5[. | +| `label_smoothing` | float | Smoothing-coefficient for label-smoothing. | ### SequenceCategoricalCrossentropy {#sequence_categorical_crossentropy tag="class"} @@ -138,7 +137,7 @@ normalize = true | `normalize` | bool | Normalize and divide by number of examples given. | | `neg_prefix` | str | Symbol that indicates that a label is negative e.g. "!sci-fi". | | `missing_value` | Union[str, int] | Symbol for "missing value" among the labels. | -| `label_smoothing` | float | Smoothing-coefficient for label-smoothing in range of [0, 0.5]. | +| `label_smoothing` | float | Smoothing-coefficient for label-smoothing. 
| ### L2Distance {#l2distance tag="class"} diff --git a/website/docs/api-util.md b/website/docs/api-util.md index e775fae78..add7c12e1 100644 --- a/website/docs/api-util.md +++ b/website/docs/api-util.md @@ -20,9 +20,9 @@ fix_random_seed(0) ### require_cpu {#require_cpu tag="function"} -Allocate data and perform operations on CPU. -If data has already been allocated on GPU, it will not be moved. -Ideally, this function should be called right after importing Thinc. +Allocate data and perform operations on CPU. If data has already been allocated +on GPU, it will not be moved. Ideally, this function should be called right +after importing Thinc. ```python ### Example @@ -69,7 +69,8 @@ require_gpu() ### set_active_gpu {#set_active_gpu tag="function"} -Set the current GPU device for `cupy` and `torch` (if available). +Set the current GPU device for `cupy` (and for `torch`, if installed) and return +a `cupy` device. Will raise an error if no GPU is available. ```python ### Example @@ -132,11 +133,13 @@ element). Converts a class vector (integers) to binary class matrix. Based on [`keras.utils.to_categorical`](https://keras.io/utils/). -| Argument | Type | Description | -| ----------- | ---------------------- | ---------------------------------------------------------------------------------------------- | -| `Y` | IntsXd | Class vector to be converted into a matrix (integers from `0` to `n_classes`). | -| `n_classes` | Optional[int] | Total number of classes. | -| **RETURNS** |  Floats2d | A binary matrix representation of the input. The axis representing the classes is placed last. | +| Argument | Type | Description | +| ----------------- | ---------------------- | ---------------------------------------------------------------------------------------------- | +| `Y` | IntsXd | Class vector to be converted into a matrix (integers from `0` to `n_classes`). | +| `n_classes` | Optional[int] | Total number of classes. | +| _keyword-only_ | | | +| `label_smoothing` | float | Smoothing-coefficient for label-smoothing. | +| **RETURNS** | Floats2d | A binary matrix representation of the input. The axis representing the classes is placed last. | ### xp2torch {#xp2torch tag="function"} @@ -165,7 +168,7 @@ Convert a `numpy` or `cupy` tensor to a TensorFlow tensor. | --------------- | -------------------------- | ----------------------------------------------------- | | `xp_tensor` | ArrayXd | The tensor to convert. | | `requires_grad` | bool | Whether to backpropagate through the variable. | -| `as_variable` | bool | Convert the result to a `tensorflow.Variable` object. | | +| `as_variable` | bool | Convert the result to a `tensorflow.Variable` object. | | **RETURNS** | tensorflow.Tensor | The converted tensor. | ### tensorflow2xp {#tensorflow2xp tag="function"} diff --git a/website/docs/install.md b/website/docs/install.md index 70ff24fc3..14f5d7bd7 100644 --- a/website/docs/install.md +++ b/website/docs/install.md @@ -47,9 +47,8 @@ is_gpu = prefer_gpu() ### Using build constraints when compiling from source If you install Thinc from source or with `pip` for platforms where there are not -binary wheels on PyPI (currently any non-`x86_64` platforms, so commonly Linux -`aarch64` or OS X M1/`arm64`), you may need to use build constraints if any -package in your environment requires an older version of `numpy`. +binary wheels on PyPI, you may need to use build constraints if any package +in your environment requires an older version of `numpy`. 
If `numpy` gets downgraded from the most recent release at any point after you've compiled `thinc`, you might see an error that looks like this: @@ -72,9 +71,9 @@ pip install thinc --no-cache-dir ``` Our build constraints currently specify the oldest supported `numpy` available -on PyPI for `x86_64`. Depending on your platform and environment, you may want -to customize the specific versions of `numpy`. For other platforms, you can have -a look at SciPy's +on PyPI for `x86_64` and `aarch64`. Depending on your platform and environment, +you may want to customize the specific versions of `numpy`. For other +platforms, you can have a look at SciPy's [`oldest-supported-numpy`](https://github.com/scipy/oldest-supported-numpy/blob/main/setup.cfg) package to see what the oldest recommended versions of `numpy` are. diff --git a/website/docs/usage-config.md b/website/docs/usage-config.md index abb6951e4..73a1638ac 100644 --- a/website/docs/usage-config.md +++ b/website/docs/usage-config.md @@ -12,15 +12,16 @@ And then once those settings are added, they become hard to remove later. Default values also become hard to change without breaking backwards compatibility. -To solve this problem, Thinc provides a config system that lets you easily -describe **arbitrary trees of objects**. The objects can be created via -**function calls you register** using a simple decorator syntax. You can even -version the functions you create, allowing you to make improvements without -breaking backwards compatibility. The most similar config system we're aware of -is [Gin](https://github.com/google/gin-config), which uses a similar syntax, and -also allows you to link the configuration system to functions in your code using -a decorator. Thinc's config system is simpler and emphasizes a different -workflow via a subset of Gin's functionality. +To solve this problem, Thinc leverages +[confection](https://github.com/explosion/confection) - a config system that +lets you easily describe **arbitrary trees of objects**. The objects can be +created via **function calls you register** using a simple decorator syntax. You +can even version the functions you create, allowing you to make improvements +without breaking backwards compatibility. The most similar config system we're +aware of is [Gin](https://github.com/google/gin-config), which uses a similar +syntax, and also allows you to link the configuration system to functions in +your code using a decorator. Thinc's config system is simpler and emphasizes a +different workflow via a subset of Gin's functionality. @@ -654,11 +655,11 @@ resolved = registry.resolve( The main motivation for Thinc's configuration system was to eliminate hidden defaults and ensure that config settings are passed around consistently. This also means that config files should always define **all available settings**. -The [`registry.fill`](/docs/api-config#registry-fill) method also -resolves the config, but it leaves references to registered functions intact and -doesn't replace them with their return values. If type annotations and/or a base -schema are available, they will be used to parse the config and fill in any -missing values and defaults to create an up-to-date "master config". +The [`registry.fill`](/docs/api-config#registry-fill) method also resolves the +config, but it leaves references to registered functions intact and doesn't +replace them with their return values. 
If type annotations and/or a base schema +are available, they will be used to parse the config and fill in any missing +values and defaults to create an up-to-date "master config". Let's say you've updated your schema and scripts to use two additional optional settings. These settings should also be reflected in your config files so they @@ -677,9 +678,9 @@ class TrainingSchema(BaseModel): max_epochs: StrictInt = 100 ``` -Calling [`registry.fill`](/docs/api-config#registry-fill) with your -existing config will produce an updated version of it including the new settings -and their defaults: +Calling [`registry.fill`](/docs/api-config#registry-fill) with your existing +config will produce an updated version of it including the new settings and +their defaults: diff --git a/website/docs/usage-sequences.md b/website/docs/usage-sequences.md index 47862ca50..72a4ed65c 100644 --- a/website/docs/usage-sequences.md +++ b/website/docs/usage-sequences.md @@ -146,9 +146,7 @@ of your network. ```python ### Example -from thinc.api import Ragged - -from thinc.api import get_current_ops, Ragged, Linear +from thinc.api import get_current_ops, Ragged, Linear, list2ragged ops = get_current_ops() sequences = [ @@ -156,14 +154,15 @@ sequences = [ ops.alloc2f(2, 5) + 2, ops.alloc2f(4, 5) + 3, ] -ragged = ops.list2ragged(sequences) +list2ragged_model = list2ragged() +ragged = list2ragged_model.predict(sequences) assert ragged.data.shape == (13, 5) # This will always be true: assert ragged.data.shape[0] == ragged.lengths.sum() # Data from sequence 0 is in the first 7 rows, followed by seqs 1 and 2 -assert ragged.data[:7] == 1 -assert ragged.data[7:2] == 2 -assert ragged.data[9:] == 3 +assert (ragged.data[:7] == 1).all() +assert (ragged.data[7:2] == 2).all() +assert (ragged.data[9:] == 3).all() # Indexing gets the batch item, and returns a Ragged object ragged[0].data.shape == (7, 5) # You can pass the data straight into dense layers @@ -173,7 +172,7 @@ ragged_out = Ragged(output, ragged.lengths) # Internally, data is reshaped to 2d. The original shape is accessible at the # the dataXd property. sequences3d = [ops.alloc3f(5, 6, 7), ops.alloc3f(10, 6, 7)] -ragged3d = ops.list2ragged(sequences3d) +ragged3d = list2ragged_model.predict(sequences3d) ragged3d.data.shape == (15, 13) ragged3d.dataXd.shape == (15, 6, 7) ``` diff --git a/website/src/pages/index.js b/website/src/pages/index.js index 037d9e3fc..2318fc8b2 100644 --- a/website/src/pages/index.js +++ b/website/src/pages/index.js @@ -20,14 +20,10 @@ export default () => ( from the makers of
spaCy - - ,{' '} - - Prodigy {' '} &{' '} - - FastAPI + + Prodigy