Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions thinc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ def get_torch_default_device() -> "torch.device":


def get_array_module(arr): # pragma: no cover
if is_cupy_array(arr):
if is_numpy_array(arr):
return numpy
elif is_cupy_array(arr):
return cupy
else:
return numpy
return None


def gpu_is_available():
Expand Down