22 changes: 17 additions & 5 deletions src/datasets/arrow_writer.py
@@ -20,10 +20,11 @@
 from dataclasses import asdict
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import pyarrow as pa
 
 from . import config, utils
-from .features import Features, _ArrayXDExtensionType
+from .features import Features, _ArrayXDExtensionType, numpy_to_pyarrow_listarray
 from .info import DatasetInfo
 from .keyhash import DuplicatedKeysError, KeyHasher
 from .utils import logging
@@ -86,14 +87,22 @@ def __arrow_array__(self, type=None):
         """This function is called when calling pa.array(typed_sequence)"""
         assert type is None, "TypedSequence is supposed to be used with pa.array(typed_sequence, type=None)"
         trying_type = False
-        if type is None and self.try_type:
+        if type is not None:  # user explicitly passed the feature
+            pass
+        elif type is None and self.try_type:
             type = self.try_type
             trying_type = True
         else:
             type = self.type
         try:
             if isinstance(type, _ArrayXDExtensionType):
-                out = pa.ExtensionArray.from_storage(type, pa.array(self.data, type.storage_dtype))
+                if isinstance(self.data, np.ndarray):
+                    storage = numpy_to_pyarrow_listarray(self.data, type=type.value_type)
+                else:
+                    storage = pa.array(self.data, type.storage_dtype)
+                out = pa.ExtensionArray.from_storage(type, storage)
+            elif isinstance(self.data, np.ndarray):
+                out = numpy_to_pyarrow_listarray(self.data)
             else:
                 out = pa.array(self.data, type=type)
             if trying_type and out[0].as_py() != self.data[0]:
@@ -111,8 +120,11 @@ def __arrow_array__(self, type=None):
             return out
         except (TypeError, pa.lib.ArrowInvalid) as e:  # handle type errors and overflows
             if trying_type:
-                try:
-                    return pa.array(self.data, type=None)  # second chance
+                try:  # second chance
+                    if isinstance(self.data, np.ndarray):
+                        return numpy_to_pyarrow_listarray(self.data, type=None)
+                    else:
+                        return pa.array(self.data, type=None)
                 except pa.lib.ArrowInvalid as e:
                     if "overflow" in str(e):
                         raise OverflowError(
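Taken together, these writer changes mean a NumPy array now travels into Arrow through numpy_to_pyarrow_listarray instead of being re-inferred from Python lists, so its dtype survives. A minimal sketch of the effect, assuming a datasets version from this PR's era where TypedSequence is importable from datasets.arrow_writer:

import numpy as np
import pyarrow as pa

from datasets.arrow_writer import TypedSequence

# A float32 matrix pushed through TypedSequence should now come out as
# list<item: float> (float32) rather than list<item: double>.
arr = np.arange(6, dtype=np.float32).reshape(3, 2)
out = pa.array(TypedSequence(arr))
print(out.type)  # expected: list<item: float>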
23 changes: 18 additions & 5 deletions src/datasets/features.py
@@ -20,6 +20,8 @@
 import sys
 from collections.abc import Iterable
 from dataclasses import dataclass, field, fields
+from functools import reduce
+from operator import mul
 from typing import Any, ClassVar, Dict, List, Optional
 from typing import Sequence as Sequence_
 from typing import Tuple, Union
@@ -144,7 +146,7 @@ def string_to_arrow(datasets_dtype: str) -> pa.DataType:
 
 def _cast_to_python_objects(obj: Any) -> Tuple[Any, bool]:
     """
-    Cast numpy/pytorch/tensorflow/pandas objects to python lists.
+    Cast pytorch/tensorflow/pandas objects to python lists or numpy arrays.
     It works recursively.
 
     To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be casted.
@@ -169,13 +171,13 @@ def _cast_to_python_objects(obj: Any) -> Tuple[Any, bool]:
         import jax.numpy as jnp
 
     if isinstance(obj, np.ndarray):
-        return obj.tolist(), True
+        return obj, False
     elif config.TORCH_AVAILABLE and "torch" in sys.modules and isinstance(obj, torch.Tensor):
-        return obj.detach().cpu().numpy().tolist(), True
+        return obj.detach().cpu().numpy(), True
     elif config.TF_AVAILABLE and "tensorflow" in sys.modules and isinstance(obj, tf.Tensor):
-        return obj.numpy().tolist(), True
+        return obj.numpy(), True
     elif config.JAX_AVAILABLE and "jax" in sys.modules and isinstance(obj, jnp.ndarray):
-        return obj.tolist(), True
+        return np.asarray(obj), True
     elif isinstance(obj, pd.Series):
         return obj.values.tolist(), True
     elif isinstance(obj, pd.DataFrame):
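These return-value changes are the core of the dtype preservation: framework tensors are now handed back as NumPy arrays instead of Python lists, and plain NumPy arrays pass through unchanged (hence the False "was casted" flag). A small sketch of the resulting behaviour, assuming the public cast_to_python_objects wrapper in datasets.features:

import numpy as np

from datasets.features import cast_to_python_objects

obj = {"vec": np.arange(3, dtype=np.float32)}
casted = cast_to_python_objects(obj)
# The array is returned as-is, so its dtype survives the cast step.
print(type(casted["vec"]), casted["vec"].dtype)  # <class 'numpy.ndarray'> float32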
@@ -950,6 +952,17 @@ def generate_from_arrow_type(pa_type: pa.DataType) -> FeatureType:
     raise ValueError(f"Cannot convert {pa_type} to a Feature type.")
 
 
+def numpy_to_pyarrow_listarray(arr: np.ndarray, type: pa.DataType = None) -> pa.ListArray:
+    """Build a PyArrow ListArray from a multidimensional NumPy array"""
+    values = pa.array(arr.flatten(), type=type)
+    for i in range(arr.ndim - 1):
+        n_offsets = reduce(mul, arr.shape[: arr.ndim - i - 1], 1)
+        step_offsets = arr.shape[arr.ndim - i - 1]
+        offsets = pa.array(np.arange(n_offsets + 1) * step_offsets, type=pa.int32())
+        values = pa.ListArray.from_arrays(offsets, values)
+    return values
+
+
 class Features(dict):
     @property
     def type(self):
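The loop in numpy_to_pyarrow_listarray rebuilds one nesting level per iteration: for shape (d1, ..., dn) it first wraps the flattened values into lists of length dn, then wraps those into lists of length dn-1, and so on outward. A worked example for a (2, 3) array using only public NumPy and PyArrow APIs, unrolling the single loop iteration by hand:

import numpy as np
import pyarrow as pa

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
values = pa.array(arr.flatten())  # 6 float32 values
# n_offsets = 2 (outer dim), step_offsets = 3 (inner dim) -> offsets [0, 3, 6]
offsets = pa.array(np.arange(2 + 1) * 3, type=pa.int32())
list_array = pa.ListArray.from_arrays(offsets, values)
print(list_array.type)         # list<item: float>, i.e. float32 preserved
print(list_array.to_pylist())  # [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]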
22 changes: 20 additions & 2 deletions tests/test_arrow_dataset.py
@@ -28,6 +28,7 @@
 from .utils import (
     assert_arrow_memory_doesnt_increase,
     assert_arrow_memory_increases,
+    require_jax,
     require_pyarrow_at_least_3,
     require_s3,
     require_tf,
@@ -998,7 +999,7 @@ def func(example):
                     self.assertEqual(len(dset_test), 30)
                     self.assertDictEqual(
                         dset_test.features,
-                        Features({"filename": Value("string"), "tensor": Sequence(Value("float64"))}),
+                        Features({"filename": Value("string"), "tensor": Sequence(Value("float32"))}),
                     )
                     self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])
 
@@ -1015,7 +1016,24 @@ def func(example):
                     self.assertEqual(len(dset_test), 30)
                     self.assertDictEqual(
                         dset_test.features,
-                        Features({"filename": Value("string"), "tensor": Sequence(Value("float64"))}),
+                        Features({"filename": Value("string"), "tensor": Sequence(Value("float32"))}),
                     )
                     self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])
 
+    @require_jax
+    def test_map_jax(self, in_memory):
+        import jax.numpy as jnp
+
+        def func(example):
+            return {"tensor": jnp.asarray([1.0, 2, 3])}
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                with dset.map(func) as dset_test:
+                    self.assertEqual(len(dset_test), 30)
+                    self.assertDictEqual(
+                        dset_test.features,
+                        Features({"filename": Value("string"), "tensor": Sequence(Value("float32"))}),
+                    )
+                    self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])
+
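End to end, these tests pin down that map outputs backed by float32 tensors now produce float32 features. A hedged sketch of the same check outside the test harness, with Dataset.from_dict standing in for the dummy dataset that _create_dummy_dataset builds:

import numpy as np

from datasets import Dataset

dset = Dataset.from_dict({"filename": [f"file_{i}" for i in range(30)]})
dset = dset.map(lambda example: {"tensor": np.array([1.0, 2.0, 3.0], dtype=np.float32)})
# With this PR the inferred feature should be Sequence(Value("float32")),
# where it previously decayed to float64.
print(dset.features["tensor"])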
54 changes: 45 additions & 9 deletions tests/test_features.py
@@ -236,6 +236,30 @@ def test_classlabel_int2str():
         classlabel.int2str(len(names))
 
 
+def iternumpy(key1, value1, value2):
+    if value1.dtype != value2.dtype:  # check only for dtype
+        raise AssertionError(
+            f"dtype of '{key1}' key for casted object: {value1.dtype} and expected object: {value2.dtype} not matching"
+        )
+
+
+def dict_diff(d1: dict, d2: dict):  # check if 2 dictionaries are equal
+
+    np.testing.assert_equal(d1, d2)  # sanity check if dict values are equal or not
+
+    for (k1, v1), (k2, v2) in zip(d1.items(), d2.items()):  # check if their values have same dtype or not
+        if isinstance(v1, dict):  # nested dictionary case
+            dict_diff(v1, v2)
+        elif isinstance(v1, np.ndarray):  # checks if dtype and value of np.ndarray is equal
+            iternumpy(k1, v1, v2)
+        elif isinstance(v1, list):
+            for (element1, element2) in zip(v1, v2):  # iterates over all elements of list
+                if isinstance(element1, dict):
+                    dict_diff(element1, element2)
+                elif isinstance(element1, np.ndarray):
+                    iternumpy(k1, element1, element2)
+
+
 class CastToPythonObjectsTest(TestCase):
     def test_cast_to_python_objects_list(self):
         obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
@@ -249,11 +273,14 @@ def test_cast_to_python_objects_tuple(self):
         casted_obj = cast_to_python_objects(obj)
         self.assertDictEqual(casted_obj, expected_obj)
 
-    def test_cast_to_python_objects_numpy(self):
+    def test_cast_to_python_or_numpy(self):
         obj = {"col_1": [{"vec": np.arange(1, 4), "txt": "foo"}] * 3, "col_2": np.arange(1, 7).reshape(3, 2)}
-        expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
+        expected_obj = {
+            "col_1": [{"vec": np.array([1, 2, 3]), "txt": "foo"}] * 3,
+            "col_2": np.array([[1, 2], [3, 4], [5, 6]]),
+        }
         casted_obj = cast_to_python_objects(obj)
-        self.assertDictEqual(casted_obj, expected_obj)
+        dict_diff(casted_obj, expected_obj)
 
     def test_cast_to_python_objects_series(self):
         obj = {
@@ -278,9 +305,12 @@ def test_cast_to_python_objects_torch(self):
             "col_1": [{"vec": torch.tensor(np.arange(1, 4)), "txt": "foo"}] * 3,
             "col_2": torch.tensor(np.arange(1, 7).reshape(3, 2)),
         }
-        expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
+        expected_obj = {
+            "col_1": [{"vec": np.array([1, 2, 3]), "txt": "foo"}] * 3,
+            "col_2": np.array([[1, 2], [3, 4], [5, 6]]),
+        }
         casted_obj = cast_to_python_objects(obj)
-        self.assertDictEqual(casted_obj, expected_obj)
+        dict_diff(casted_obj, expected_obj)
 
     @require_tf
     def test_cast_to_python_objects_tf(self):
@@ -290,9 +320,12 @@ def test_cast_to_python_objects_tf(self):
             "col_1": [{"vec": tf.constant(np.arange(1, 4)), "txt": "foo"}] * 3,
             "col_2": tf.constant(np.arange(1, 7).reshape(3, 2)),
         }
-        expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
+        expected_obj = {
+            "col_1": [{"vec": np.array([1, 2, 3]), "txt": "foo"}] * 3,
+            "col_2": np.array([[1, 2], [3, 4], [5, 6]]),
+        }
         casted_obj = cast_to_python_objects(obj)
-        self.assertDictEqual(casted_obj, expected_obj)
+        dict_diff(casted_obj, expected_obj)
 
     @require_jax
     def test_cast_to_python_objects_jax(self):
@@ -302,9 +335,12 @@ def test_cast_to_python_objects_jax(self):
             "col_1": [{"vec": jnp.array(np.arange(1, 4)), "txt": "foo"}] * 3,
             "col_2": jnp.array(np.arange(1, 7).reshape(3, 2)),
         }
-        expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
+        expected_obj = {
+            "col_1": [{"vec": np.array([1, 2, 3]), "txt": "foo"}] * 3,
+            "col_2": np.array([[1, 2], [3, 4], [5, 6]]),
+        }
         casted_obj = cast_to_python_objects(obj)
-        self.assertDictEqual(casted_obj, expected_obj)
+        dict_diff(casted_obj, expected_obj)
 
     @patch("datasets.features._cast_to_python_objects", side_effect=_cast_to_python_objects)
     def test_dont_iterate_over_each_element_in_a_list(self, mocked_cast):
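dict_diff exists because assertDictEqual compares values but is blind to NumPy dtypes, which is exactly the property these tests must now verify. A quick illustration of the failure mode it catches, assuming the iternumpy/dict_diff helpers defined earlier in this file:

import numpy as np

d1 = {"vec": np.array([1, 2, 3], dtype=np.float32)}
d2 = {"vec": np.array([1, 2, 3], dtype=np.float64)}
# Values match, so np.testing.assert_equal passes, but iternumpy then raises
# an AssertionError because float32 != float64.
dict_diff(d1, d2)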