Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from . import config, utils
from .arrow_reader import ArrowReader
from .arrow_writer import ArrowWriter, OptimizedTypedSequence
from .features import ClassLabel, Features, Value, cast_to_python_objects
from .features import ClassLabel, Features, Value
from .filesystems import extract_path_from_uri, is_remote_filesystem
from .fingerprint import (
fingerprint_transform,
Expand Down Expand Up @@ -449,8 +449,6 @@ def from_dict(
info.features = features
if features is not None:
mapping = features.encode_batch(mapping)
else:
mapping = cast_to_python_objects(mapping)
mapping = {
col: OptimizedTypedSequence(data, type=features.type[col].type if features is not None else None, col=col)
for col, data in mapping.items()
Expand Down Expand Up @@ -2037,7 +2035,6 @@ def init_buffer_and_writer():
if isinstance(example, pa.Table):
writer.write_row(example)
else:
example = cast_to_python_objects(example)
writer.write(example)
else:
for i in pbar:
Expand Down Expand Up @@ -2065,7 +2062,6 @@ def init_buffer_and_writer():
if isinstance(batch, pa.Table):
writer.write_table(batch)
else:
batch = cast_to_python_objects(batch)
writer.write_batch(batch)
if update_data and writer is not None:
writer.finalize() # close_stream=bool(buf_writer is None)) # We only close if we are writing in a file
Expand Down
12 changes: 10 additions & 2 deletions src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@
import pyarrow as pa

from . import config, utils
from .features import Features, _ArrayXDExtensionType, numpy_to_pyarrow_listarray
from .features import (
Features,
_ArrayXDExtensionType,
cast_to_python_objects,
list_of_np_array_to_pyarrow_listarray,
numpy_to_pyarrow_listarray,
)
from .info import DatasetInfo
from .keyhash import DuplicatedKeysError, KeyHasher
from .utils import logging
Expand Down Expand Up @@ -103,8 +109,10 @@ def __arrow_array__(self, type=None):
out = pa.ExtensionArray.from_storage(type, storage)
elif isinstance(self.data, np.ndarray):
out = numpy_to_pyarrow_listarray(self.data)
elif isinstance(self.data, list) and self.data and isinstance(self.data[0], np.ndarray):
out = list_of_np_array_to_pyarrow_listarray(self.data)
else:
out = pa.array(self.data, type=type)
out = pa.array(cast_to_python_objects(self.data, only_1d_for_numpy=True), type=type)
if trying_type and out[0].as_py() != self.data[0]:
raise TypeError(
"Specified try_type alters data. Please check that the type/feature that you provided match the type/features of the data."
Expand Down
53 changes: 43 additions & 10 deletions src/datasets/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def string_to_arrow(datasets_dtype: str) -> pa.DataType:
return pa.__dict__[arrow_data_factory_function_name]()


def _cast_to_python_objects(obj: Any) -> Tuple[Any, bool]:
def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool) -> Tuple[Any, bool]:
"""
    Cast pytorch/tensorflow/pandas objects to python lists or numpy arrays.
It works recursively.
Expand All @@ -155,6 +155,9 @@ def _cast_to_python_objects(obj: Any) -> Tuple[Any, bool]:

Args:
obj: the object (nested struct) to cast
only_1d_for_numpy (bool): whether to keep the full multi-dim tensors as multi-dim numpy arrays, or convert them to
nested lists of 1-dimensional numpy arrays. This can be useful to keep only 1-d arrays to instantiate Arrow arrays.
        Indeed, Arrow only supports converting 1-dimensional array values.

Returns:
casted_obj: the casted object
Expand All @@ -171,13 +174,27 @@ def _cast_to_python_objects(obj: Any) -> Tuple[Any, bool]:
import jax.numpy as jnp

if isinstance(obj, np.ndarray):
return obj.tolist(), False
if not only_1d_for_numpy or obj.ndim == 1:
return obj, False
else:
return [_cast_to_python_objects(x, only_1d_for_numpy=only_1d_for_numpy)[0] for x in obj], True
elif config.TORCH_AVAILABLE and "torch" in sys.modules and isinstance(obj, torch.Tensor):
return obj.detach().cpu().numpy(), True
if not only_1d_for_numpy or obj.ndim == 1:
return obj.detach().cpu().numpy(), True
else:
return [
_cast_to_python_objects(x, only_1d_for_numpy=only_1d_for_numpy)[0] for x in obj.detach().cpu().numpy()
], True
elif config.TF_AVAILABLE and "tensorflow" in sys.modules and isinstance(obj, tf.Tensor):
return obj.numpy(), True
if not only_1d_for_numpy or obj.ndim == 1:
return obj.numpy(), True
else:
return [_cast_to_python_objects(x, only_1d_for_numpy=only_1d_for_numpy)[0] for x in obj.numpy()], True
elif config.JAX_AVAILABLE and "jax" in sys.modules and isinstance(obj, jnp.ndarray):
return np.asarray(obj), True
if not only_1d_for_numpy or obj.ndim == 1:
return np.asarray(obj), True
else:
return [_cast_to_python_objects(x, only_1d_for_numpy=only_1d_for_numpy)[0] for x in np.asarray(obj)], True
elif isinstance(obj, pd.Series):
return obj.values.tolist(), True
elif isinstance(obj, pd.DataFrame):
Expand All @@ -186,7 +203,7 @@ def _cast_to_python_objects(obj: Any) -> Tuple[Any, bool]:
output = {}
has_changed = False
for k, v in obj.items():
casted_v, has_changed_v = _cast_to_python_objects(v)
casted_v, has_changed_v = _cast_to_python_objects(v, only_1d_for_numpy=only_1d_for_numpy)
has_changed |= has_changed_v
output[k] = casted_v
return output if has_changed else obj, has_changed
Expand All @@ -195,9 +212,11 @@ def _cast_to_python_objects(obj: Any) -> Tuple[Any, bool]:
for first_elmt in obj:
if first_elmt is not None:
break
casted_first_elmt, has_changed_first_elmt = _cast_to_python_objects(first_elmt)
casted_first_elmt, has_changed_first_elmt = _cast_to_python_objects(
first_elmt, only_1d_for_numpy=only_1d_for_numpy
)
if has_changed_first_elmt:
return [_cast_to_python_objects(elmt)[0] for elmt in obj], True
return [_cast_to_python_objects(elmt, only_1d_for_numpy=only_1d_for_numpy)[0] for elmt in obj], True
else:
if isinstance(obj, list):
return obj, False
Expand All @@ -209,7 +228,7 @@ def _cast_to_python_objects(obj: Any) -> Tuple[Any, bool]:
return obj, False


def cast_to_python_objects(obj: Any) -> Any:
def cast_to_python_objects(obj: Any, only_1d_for_numpy=False) -> Any:
"""
Cast numpy/pytorch/tensorflow/pandas objects to python lists.
It works recursively.
Expand All @@ -224,7 +243,7 @@ def cast_to_python_objects(obj: Any) -> Any:
Returns:
casted_obj: the casted object
"""
return _cast_to_python_objects(obj)[0]
return _cast_to_python_objects(obj, only_1d_for_numpy=only_1d_for_numpy)[0]


@dataclass
Expand Down Expand Up @@ -963,6 +982,20 @@ def numpy_to_pyarrow_listarray(arr: np.ndarray, type: pa.DataType = None) -> pa.
return values


def list_of_pa_arrays_to_pyarrow_listarray(l_arr: List[pa.Array]) -> pa.ListArray:
    """Concatenate a list of PyArrow arrays into a single ``pa.ListArray``.

    Each input array becomes one entry (sub-list) of the resulting ListArray;
    the offsets are the running totals of the input array lengths.
    """
    # Offsets delimit where each sub-list starts/ends in the flattened values.
    lengths = [len(arr) for arr in l_arr]
    offsets = pa.array(np.cumsum([0] + lengths), type=pa.int32())
    flattened_values = pa.concat_arrays(l_arr)
    return pa.ListArray.from_arrays(offsets, flattened_values)


def list_of_np_array_to_pyarrow_listarray(l_arr: List[np.ndarray], type: pa.DataType = None) -> pa.ListArray:
    """Build a PyArrow ListArray from a possibly nested list of NumPy arrays"""
    if not l_arr:
        # Nothing to convert: produce an empty Arrow array of the requested type.
        return pa.array([], type=type)
    # Convert each NumPy array on its own, then stitch them into one ListArray.
    converted = [numpy_to_pyarrow_listarray(arr, type=type) for arr in l_arr]
    return list_of_pa_arrays_to_pyarrow_listarray(converted)


class Features(dict):
@property
def type(self):
Expand Down