3 changes: 2 additions & 1 deletion thinc/api.py
@@ -16,8 +16,9 @@
 from .util import DataValidationError, data_validation
 from .util import to_categorical, get_width, get_array_module, to_numpy
 from .util import torch2xp, xp2torch, tensorflow2xp, xp2tensorflow, mxnet2xp, xp2mxnet
+from .compat import has_cupy
 from .backends import get_ops, set_current_ops, get_current_ops, use_ops
-from .backends import Ops, CupyOps, NumpyOps, has_cupy, set_gpu_allocator
+from .backends import Ops, CupyOps, NumpyOps, set_gpu_allocator
 from .backends import use_pytorch_for_gpu_memory, use_tensorflow_for_gpu_memory
 
 from .layers import Dropout, Embed, expand_window, HashEmbed, LayerNorm, Linear
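Since `thinc.api` now sources `has_cupy` from `thinc.compat` rather than `thinc.backends`, the public import path is unchanged. A minimal sketch of the two equivalent imports after this change:

    from thinc.api import has_cupy              # unchanged public path
    from thinc.compat import has_cupy as flag   # new canonical home

    assert has_cupy is flag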
9 changes: 3 additions & 6 deletions thinc/backends/__init__.py
@@ -5,13 +5,14 @@
 import threading
 
 from .ops import Ops
-from .cupy_ops import CupyOps, has_cupy
+from .cupy_ops import CupyOps
 from .numpy_ops import NumpyOps
 from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator
 from ._param_server import ParamServer
 from ..util import assert_tensorflow_installed, assert_pytorch_installed
 from ..util import is_cupy_array, set_torch_tensor_type_for_ops, require_cpu
 from .. import registry
+from ..compat import cupy, has_cupy
 
 
 context_ops: ContextVar[Optional[Ops]] = ContextVar("context_ops", default=None)
@@ -46,8 +47,6 @@ def use_pytorch_for_gpu_memory() -> None:  # pragma: no cover
     We'd like to support routing Tensorflow memory allocation via PyTorch as well
     (or vice versa), but do not currently have an implementation for it.
     """
-    import cupy.cuda
-
     assert_pytorch_installed()
     pools = context_pools.get()
     if "pytorch" not in pools:
@@ -65,8 +64,6 @@ def use_tensorflow_for_gpu_memory() -> None:  # pragma: no cover
     We'd like to support routing PyTorch memory allocation via Tensorflow as
     well (or vice versa), but do not currently have an implementation for it.
     """
-    import cupy.cuda
-
     assert_tensorflow_installed()
     pools = context_pools.get()
     if "tensorflow" not in pools:
@@ -94,7 +91,7 @@ def get_ops(name: str, **kwargs) -> Ops:
 
     cls: Optional[Callable[..., Ops]] = None
     if name == "cpu":
-        _import_extra_cpu_backends()
+        _import_extra_cpu_backends()
         cls = ops_by_name.get("numpy")
         cls = ops_by_name.get("apple", cls)
         cls = ops_by_name.get("bigendian", cls)
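The `use_*_for_gpu_memory` helpers no longer import `cupy.cuda` inline; they rely on the `cupy` module object exposed by `thinc.compat`. An illustrative call sequence, assuming a CUDA-capable machine with both cupy and PyTorch installed:

    from thinc.api import require_gpu, use_pytorch_for_gpu_memory

    require_gpu()                 # switch the current ops to CupyOps
    use_pytorch_for_gpu_memory()  # route cupy allocations via PyTorch's pool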
25 changes: 5 additions & 20 deletions thinc/backends/_cupy_allocators.py
@@ -2,22 +2,7 @@
 
 from ..types import ArrayXd
 from ..util import tensorflow2xp
-
-try:
-    import tensorflow
-except ImportError:
-    pass
-
-try:
-    import torch
-except ImportError:
-    pass
-
-try:
-    from cupy.cuda.memory import MemoryPointer
-    from cupy.cuda.memory import UnownedMemory
-except ImportError:
-    pass
+from ..compat import torch, cupy, tensorflow
 
 
 def cupy_tensorflow_allocator(size_in_bytes: int):
@@ -32,9 +17,9 @@ def cupy_tensorflow_allocator(size_in_bytes: int):
     cupy_array = cast(ArrayXd, tensorflow2xp(tensor))
     address = int(cupy_array.data)
     # cupy has a neat class to help us here. Otherwise it will try to free.
-    memory = UnownedMemory(address, size_in_bytes, cupy_array)
+    memory = cupy.cuda.memory.UnownedMemory(address, size_in_bytes, cupy_array)
     # Now return a new memory pointer.
-    return MemoryPointer(memory, 0)
+    return cupy.cuda.memory.MemoryPointer(memory, 0)
 
 
 def cupy_pytorch_allocator(size_in_bytes: int):
@@ -53,6 +38,6 @@ def cupy_pytorch_allocator(size_in_bytes: int):
     # cupy has a neat class to help us here. Otherwise it will try to free.
     # I think this is a private API? It's not in the types.
     address = torch_tensor.data_ptr()  # type: ignore
-    memory = UnownedMemory(address, size_in_bytes, torch_tensor)
+    memory = cupy.cuda.memory.UnownedMemory(address, size_in_bytes, torch_tensor)
     # Now return a new memory pointer.
-    return MemoryPointer(memory, 0)
+    return cupy.cuda.memory.MemoryPointer(memory, 0)
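For context, `UnownedMemory` wraps an externally owned pointer so cupy will not try to free it, and `MemoryPointer` makes it usable by cupy's memory pools. A sketch of how such an allocator is typically installed, assuming cupy with a working GPU:

    import cupy
    from thinc.backends._cupy_allocators import cupy_pytorch_allocator

    # Wrap the allocator in a memory pool and make it cupy's default malloc.
    pool = cupy.cuda.MemoryPool(allocator=cupy_pytorch_allocator)
    cupy.cuda.set_allocator(pool.malloc)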
12 changes: 4 additions & 8 deletions thinc/backends/_custom_kernels.py
@@ -2,11 +2,7 @@
 import re
 from pathlib import Path
 from collections import defaultdict
-
-try:
-    import cupy
-except ImportError:
-    cupy = None
+from ..compat import cupy, has_cupy_gpu
 
 
 PWD = Path(__file__).parent
@@ -55,7 +51,7 @@
     cupy.RawModule(
         code=KERNELS_SRC, options=("--std=c++11",), name_expressions=KERNELS_LIST
     )
-    if cupy is not None
+    if has_cupy_gpu
     else None
 )
@@ -70,7 +66,7 @@ def _get_kernel(name):
 
 
 def compile_mmh(src):
-    if cupy is None:
+    if not has_cupy_gpu:
         return None
     return cupy.RawKernel(src, "hash_data")
 
@@ -672,7 +668,7 @@ def _check_which_maxout(which, B: int, I: int, P: int):
         "true",
         "within_range",
     )
-    if cupy is not None
+    if has_cupy_gpu
     else None
 )
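The new guard means kernels are compiled only when cupy can actually see a GPU, not merely when the cupy package imports. A minimal sketch of the same pattern with a hypothetical kernel (not one of thinc's):

    from thinc.compat import cupy, has_cupy_gpu

    _SRC = r'''
    extern "C" __global__ void scale(float* x, float a, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) x[i] *= a;
    }
    '''

    # None on CPU-only machines; callers must check before launching.
    scale_kernel = cupy.RawKernel(_SRC, "scale") if has_cupy_gpu else None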
16 changes: 1 addition & 15 deletions thinc/backends/cupy_ops.py
@@ -1,19 +1,4 @@
 import numpy
-
-try:
-    import cupy
-    import cupyx
-    import cupy.cuda
-    from cupy.cuda.compiler import compile_with_cache  # noqa: F401
-
-    has_cupy = True
-
-    # We no longer have to set up the memory pool, fortunately.
-except ImportError:
-    cupy = None
-    cupyx = None
-    has_cupy = False
-
 from .. import registry
 from .ops import Ops
 from .numpy_ops import NumpyOps
@@ -22,6 +7,7 @@
 from ..util import torch2xp, tensorflow2xp, mxnet2xp
 from ..util import is_cupy_array
 from ..util import is_torch_gpu_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
+from ..compat import cupy, cupyx
 
 
 @registry.ops("CupyOps")
70 changes: 70 additions & 0 deletions thinc/compat.py
@@ -0,0 +1,70 @@
+from packaging.version import Version
+
+try:  # pragma: no cover
+    import cupy
+    import cupyx
+
+    has_cupy = True
+    cupy_version = Version(cupy.__version__)
+    try:
+        cupy.cuda.runtime.getDeviceCount()
+        has_cupy_gpu = True
+    except cupy.cuda.runtime.CUDARuntimeError:
+        has_cupy_gpu = False
+
+    if cupy_version.major >= 10:
+        # fromDlpack was deprecated in v10.0.0.
+        cupy_from_dlpack = cupy.from_dlpack
+    else:
+        cupy_from_dlpack = cupy.fromDlpack
+except (ImportError, AttributeError):
+    cupy = None
+    cupyx = None
+    cupy_version = Version("0.0.0")
+    has_cupy = False
+    cupy_from_dlpack = None
+    has_cupy_gpu = False
+
+
+try:  # pragma: no cover
+    import torch.utils.dlpack
+    import torch
+
+    has_torch = True
+    has_torch_gpu = torch.cuda.device_count() != 0
+    torch_version = Version(str(torch.__version__))
+    has_torch_amp = (
+        torch_version >= Version("1.9.0")
+        and not torch.cuda.amp.common.amp_definitely_not_available()
+    )
+except ImportError:  # pragma: no cover
+    torch = None
+    has_torch = False
+    has_torch_gpu = False
+    has_torch_amp = False
+    torch_version = Version("0.0.0")
+
+try:  # pragma: no cover
+    import tensorflow.experimental.dlpack
+    import tensorflow
+
+    has_tensorflow = True
+    has_tensorflow_gpu = len(tensorflow.config.get_visible_devices("GPU")) > 0
+except ImportError:  # pragma: no cover
+    tensorflow = None
+    has_tensorflow = False
+    has_tensorflow_gpu = False
+
+
+try:  # pragma: no cover
+    import mxnet
+
+    has_mxnet = True
+except ImportError:  # pragma: no cover
+    mxnet = None
+    has_mxnet = False
+
+try:
+    import h5py
+except ImportError:  # pragma: no cover
+    h5py = None
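Downstream modules can now branch on these flags instead of repeating try/except blocks at each call site. An illustrative sketch (the helper name is hypothetical):

    from thinc.compat import has_torch, torch

    def count_torch_gpus() -> int:
        # torch is bound to None when PyTorch is missing, so test the flag first.
        if not has_torch:
            return 0
        return torch.cuda.device_count()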
6 changes: 1 addition & 5 deletions thinc/layers/tensorflowwrapper.py
@@ -7,11 +7,7 @@
 from ..util import xp2tensorflow, tensorflow2xp, assert_tensorflow_installed
 from ..util import is_tensorflow_array, convert_recursive, is_xp_array
 from ..types import ArrayXd, ArgsKwargs
-
-try:
-    import tensorflow as tf
-except ImportError:  # pragma: no cover
-    pass
+from ..compat import tensorflow as tf
 
 InT = TypeVar("InT")
 OutT = TypeVar("OutT")
12 changes: 3 additions & 9 deletions thinc/shims/mxnet.py
@@ -2,18 +2,12 @@
 import srsly
 import copy
 
-try:
-    import mxnet.autograd
-    import mxnet.optimizer
-    import mxnet as mx
-except ImportError:  # pragma: no cover
-    pass
-
 from ..util import mxnet2xp, convert_recursive, make_tempfile, xp2mxnet
 from ..util import get_array_module
 from ..optimizers import Optimizer
 from ..types import ArgsKwargs, FloatsXd
 from .shim import Shim
+from ..compat import mxnet as mx
 
 
 class MXNetShim(Shim):
@@ -33,7 +27,7 @@ def predict(self, inputs: ArgsKwargs) -> Any:
         evaluation mode.
         """
         mx.autograd.set_training(train_mode=False)
-        with mxnet.autograd.pause():
+        with mx.autograd.pause():
             outputs = self._model(*inputs.args, **inputs.kwargs)
         mx.autograd.set_training(train_mode=True)
         return outputs
@@ -50,7 +44,7 @@ def begin_update(self, inputs: ArgsKwargs):
 
         def backprop(grads):
             mx.autograd.set_recording(False)
-            mxnet.autograd.backward(*grads.args, **grads.kwargs)
+            mx.autograd.backward(*grads.args, **grads.kwargs)
             return convert_recursive(
                 lambda x: hasattr(x, "grad"), lambda x: x.grad, inputs
            )
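With only the `mx` alias imported from `thinc.compat`, the stray `mxnet.autograd` references are normalized to `mx.autograd`. A condensed sketch of the inference pattern the shim uses, assuming MXNet is installed:

    from thinc.compat import mxnet as mx

    def predict(model, *inputs):
        # Disable training behaviour and pause gradient recording for inference.
        mx.autograd.set_training(train_mode=False)
        with mx.autograd.pause():
            outputs = model(*inputs)
        mx.autograd.set_training(train_mode=True)
        return outputs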
14 changes: 3 additions & 11 deletions thinc/shims/pytorch.py
@@ -4,16 +4,8 @@
 import itertools
 import srsly
 
-try:
-    import torch.autograd
-    from torch.cuda import amp
-    import torch.optim
-    import torch
-except ImportError:  # pragma: no cover
-    pass
-
 from ..util import torch2xp, xp2torch, convert_recursive, iterate_recursive
-from ..util import has_torch_amp
+from ..compat import torch
 from ..backends import get_current_ops, context_pools, CupyOps
 from ..backends import set_gpu_allocator
 from ..optimizers import Optimizer
@@ -73,7 +65,7 @@ def predict(self, inputs: ArgsKwargs) -> Any:
         """
         self._model.eval()
         with torch.no_grad():
-            with amp.autocast(self._mixed_precision):
+            with torch.cuda.amp.autocast(self._mixed_precision):
                 outputs = self._model(*inputs.args, **inputs.kwargs)
         self._model.train()
         return outputs
@@ -87,7 +79,7 @@ def begin_update(self, inputs: ArgsKwargs):
         self._model.train()
 
         # Note: mixed-precision autocast must not be applied to backprop.
-        with amp.autocast(self._mixed_precision):
+        with torch.cuda.amp.autocast(self._mixed_precision):
             output = self._model(*inputs.args, **inputs.kwargs)
 
         def backprop(grads):
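With the `amp` alias gone, autocast is spelled out as `torch.cuda.amp.autocast`. A condensed sketch of the pattern the shim follows, assuming PyTorch with CUDA, where only the forward pass runs under autocast:

    import torch

    def forward_mixed(model: torch.nn.Module, x: torch.Tensor, enabled: bool):
        # Autocast covers the forward pass only; backprop must run outside it.
        with torch.cuda.amp.autocast(enabled):
            return model(x)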
8 changes: 2 additions & 6 deletions thinc/shims/pytorch_grad_scaler.py
@@ -1,11 +1,7 @@
 from typing import Dict, Iterable, List, Union, cast
 
-from ..util import has_torch_amp, is_torch_array
-
-try:
-    import torch
-except ImportError:  # pragma: no cover
-    pass
+from ..compat import has_torch_amp, torch
+from ..util import is_torch_array
 
 
 class PyTorchGradScaler:
17 changes: 2 additions & 15 deletions thinc/shims/tensorflow.py
@@ -10,21 +10,8 @@
 from ..types import ArgsKwargs, ArrayXd
 from ..util import get_array_module
 from .shim import Shim
-
-try:
-    import cupy
-except ImportError:
-    cupy = None
-
-try:
-    import tensorflow as tf
-except ImportError:  # pragma: no cover
-    pass
-
-try:
-    import h5py
-except ImportError:  # pragma: no cover
-    pass
+from ..compat import tensorflow as tf
+from ..compat import cupy, h5py
 
 keras_model_fns = catalogue.create("thinc", "keras", entry_points=True)
 
3 changes: 2 additions & 1 deletion thinc/tests/backends/test_ops.py
@@ -8,7 +8,8 @@
 from packaging.version import Version
 from thinc.api import NumpyOps, CupyOps, Ops, get_ops
 from thinc.api import get_current_ops, use_ops
-from thinc.util import has_torch, torch2xp, xp2torch, torch_version, gpu_is_available
+from thinc.util import torch2xp, xp2torch, gpu_is_available
+from thinc.compat import has_torch, torch_version
 from thinc.api import fix_random_seed
 from thinc.api import LSTM
 from thinc.types import Floats2d
2 changes: 1 addition & 1 deletion thinc/tests/layers/test_layers_api.py
@@ -5,7 +5,7 @@
 from thinc.backends import NumpyOps
 from thinc.util import data_validation, get_width
 from thinc.types import Ragged, Padded, Array2d, Floats2d, FloatsXd, Shape
-from thinc.util import has_torch
+from thinc.compat import has_torch
 import numpy
 import pytest
 
2 changes: 1 addition & 1 deletion thinc/tests/layers/test_lstm.py
@@ -2,7 +2,7 @@
 import timeit
 from thinc.api import NumpyOps, LSTM, PyTorchLSTM, with_padded, fix_random_seed
 from thinc.api import Ops
-from thinc.util import has_torch
+from thinc.compat import has_torch
 import pytest
 
 