Skip to content

Commit a9fe1af

Browse files
committed
CupyOps: Simplify asarray (explosion#661)
* `CupyOps`: Simplify `asarray` * Remove `cast_array` flag and use `astype` unconditionally * Revert unconditional call to `astype` * Remove no-op
1 parent abe1eb4 commit a9fe1af

File tree

1 file changed

+8
-17
lines changed

1 file changed

+8
-17
lines changed

thinc/backends/cupy_ops.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -72,29 +72,20 @@ def gemm(self, x, y, out=None, trans1=False, trans2=False):
7272
return out
7373

7474
def asarray(self, data, dtype=None):
75-
# This is sort of frustrating, but we can't easily otherwise pass
76-
# forward "unset".
77-
dtype = {"dtype": dtype} if dtype is not None else {}
78-
7975
# We'll try to perform a zero-copy conversion if possible.
80-
array = None
81-
cast_array = False
82-
if isinstance(data, cupy.ndarray):
83-
array = self.xp.asarray(data, **dtype)
84-
elif is_torch_array(data) and data.device.type == "cuda":
76+
if is_cupy_array(data):
77+
array = data
78+
elif is_torch_gpu_array(data):
8579
array = torch2xp(data)
86-
cast_array = True
87-
elif is_tensorflow_array(data) and "GPU:" in data.device:
80+
elif is_tensorflow_gpu_array(data):
8881
array = tensorflow2xp(data)
89-
cast_array = True
90-
elif is_mxnet_array(data) and data.context.device_type != "cpu":
82+
elif is_mxnet_gpu_array(data):
9183
array = mxnet2xp(data)
92-
cast_array = True
9384
else:
94-
array = self.xp.array(data, **dtype)
85+
array = self.xp.array(data)
9586

96-
if cast_array and dtype != {}:
97-
array = array.astype(dtype["dtype"])
87+
if dtype is not None:
88+
array = array.astype(dtype=dtype, copy=False)
9889

9990
return array
10091

0 commit comments

Comments
 (0)