diff --git a/thinc/util.py b/thinc/util.py index cdfdf4c1e..6ace7e7ed 100644 --- a/thinc/util.py +++ b/thinc/util.py @@ -314,6 +314,8 @@ def xp2torch( if hasattr(xp_tensor, "toDlpack"): dlpack_tensor = xp_tensor.toDlpack() # type: ignore torch_tensor = torch.utils.dlpack.from_dlpack(dlpack_tensor) + elif hasattr(xp_tensor, "__dlpack__"): + torch_tensor = torch.utils.dlpack.from_dlpack(xp_tensor) else: torch_tensor = torch.from_numpy(xp_tensor) if requires_grad: @@ -350,6 +352,8 @@ def xp2tensorflow( if hasattr(xp_tensor, "toDlpack"): dlpack_tensor = xp_tensor.toDlpack() # type: ignore tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_tensor) + elif hasattr(xp_tensor, "__dlpack__"): + tf_tensor = tf.experimental.dlpack.from_dlpack(xp_tensor) else: tf_tensor = tf.convert_to_tensor(xp_tensor) if as_variable: