diff --git a/thinc/tests/test_util.py b/thinc/tests/test_util.py index 715d381d5..76985b24b 100644 --- a/thinc/tests/test_util.py +++ b/thinc/tests/test_util.py @@ -3,11 +3,19 @@ from hypothesis import given from thinc.api import get_width, Ragged, Padded from thinc.util import get_array_module, is_numpy_array, to_categorical +from thinc.util import is_cupy_array from thinc.util import convert_recursive from thinc.types import ArgsKwargs from . import strategies +ALL_XP = [numpy] +try: + import cupy + ALL_XP.append(cupy) +except ImportError: + pass + @pytest.mark.parametrize( "obj,width", @@ -39,11 +47,23 @@ def test_get_width_fail(obj): get_width(obj) -def test_array_module_cpu_gpu_helpers(): - xp = get_array_module(0) - assert hasattr(xp, "ndarray") - assert is_numpy_array(numpy.zeros((1, 2))) - assert not is_numpy_array((1, 2)) +@pytest.mark.parametrize("xp", ALL_XP) +def test_array_module_cpu_gpu_helpers(xp): + error = ("Only numpy and cupy arrays are supported" + ", but found instead. If " + "get_array_module module wasn't called " + "directly, this might indicate a bug in Thinc.") + with pytest.raises(ValueError, match=error): + get_array_module(0) + zeros = xp.zeros((1, 2)) + xp_ = get_array_module(zeros) + assert xp_ == xp + if xp == numpy: + assert is_numpy_array(zeros) + assert not is_numpy_array((1, 2)) + else: + assert is_cupy_array(zeros) + assert not is_cupy_array((1, 2)) @given( diff --git a/thinc/util.py b/thinc/util.py index e46c62447..ab2680dd1 100644 --- a/thinc/util.py +++ b/thinc/util.py @@ -47,10 +47,17 @@ def get_torch_default_device() -> "torch.device": def get_array_module(arr): # pragma: no cover - if is_cupy_array(arr): + if is_numpy_array(arr): + return numpy + elif is_cupy_array(arr): return cupy else: - return numpy + raise ValueError( + "Only numpy and cupy arrays are supported" + f", but found {type(arr)} instead. If " + "get_array_module module wasn't called " + "directly, this might indicate a bug in Thinc." + ) def gpu_is_available():