Skip to content
Merged
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
- name: Install dependencies (latest versions)
if: ${{ matrix.deps_versions == 'deps-latest' }}
run: pip install --upgrade pyarrow huggingface-hub dill
- name: Install depencencies (minimum versions)
- name: Install dependencies (minimum versions)
if: ${{ matrix.deps_versions != 'deps-latest' }}
run: pip install pyarrow==8.0.0 huggingface-hub==0.18.0 transformers dill==0.3.1.1
- name: Test with pytest
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@
# Backend and serialization.
# Minimum 8.0.0 to be able to use .to_reader()
"pyarrow>=8.0.0",
# Needed to patch the CVE-2023-47248 vulnerability for as long as we allow pyarrow < 14.0.1
"pyarrow-hotfix",
# For smart caching dataset processing
"dill>=0.3.0,<0.3.8", # tmp pin until dill has official support for determinism see https://github.com/uqfoundation/dill/issues/19
# For performance gains with apache arrow
Expand Down
26 changes: 20 additions & 6 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.types
import pyarrow_hotfix # noqa: F401 # to fix vulnerability on pyarrow<14.0.1
from pandas.api.extensions import ExtensionArray as PandasExtensionArray
from pandas.api.extensions import ExtensionDtype as PandasExtensionDtype

Expand Down Expand Up @@ -631,7 +632,7 @@ class Array5D(_ArrayXD):
_type: str = field(default="Array5D", init=False, repr=False)


class _ArrayXDExtensionType(pa.PyExtensionType):
class _ArrayXDExtensionType(pa.ExtensionType):
ndims: Optional[int] = None

def __init__(self, shape: tuple, dtype: str):
Expand All @@ -645,13 +646,19 @@ def __init__(self, shape: tuple, dtype: str):
self.shape = tuple(shape)
self.value_type = dtype
self.storage_dtype = self._generate_dtype(self.value_type)
pa.PyExtensionType.__init__(self, self.storage_dtype)
pa.ExtensionType.__init__(self, self.storage_dtype, f"{self.__class__.__module__}.{self.__class__.__name__}")

def __arrow_ext_serialize__(self):
    """Serialize this extension type's parameters as JSON-encoded bytes.

    The (shape, value_type) pair is everything needed to rebuild the type;
    it is decoded again in ``__arrow_ext_deserialize__``.
    """
    return json.dumps((self.shape, self.value_type)).encode()

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized):
    """Rebuild the extension type from the payload written by ``__arrow_ext_serialize__``.

    ``storage_type`` is unused: the (shape, value_type) args fully determine
    the storage dtype via the constructor.
    """
    args = json.loads(serialized)
    return cls(*args)

# This was added to pa.ExtensionType in pyarrow >= 13.0.0
def __reduce__(self):
    """Make the extension type picklable on pyarrow < 13.0.0.

    Pickles via the Arrow extension serialization round-trip
    (``__arrow_ext_deserialize__`` applied to the serialized payload),
    so pickling and Arrow IPC deserialization share one code path.

    NOTE(review): the pasted diff retained both the old body (returning
    ``self.__class__, (self.shape, self.value_type)``) and the new one-liner,
    leaving a dead second ``return``; only the new, merged body is kept here.
    """
    return self.__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())

def __hash__(self):
    """Hash on (class, shape, value_type) — the same identity serialized by
    ``__arrow_ext_serialize__`` — so equal extension types hash equally."""
    return hash((self.__class__, self.shape, self.value_type))
Expand Down Expand Up @@ -687,6 +694,13 @@ class Array5DExtensionType(_ArrayXDExtensionType):
ndims = 5


# Register the extension types for deserialization.
# Unlike the removed pa.PyExtensionType, pa.ExtensionType requires explicit
# registration so Arrow can locate __arrow_ext_deserialize__ when reading data.
# The concrete shape/dtype arguments here are presumably dummies used only to
# construct an instance for registration — TODO confirm only the registered
# extension name matters for lookup.
pa.register_extension_type(Array2DExtensionType((1, 2), "int64"))
pa.register_extension_type(Array3DExtensionType((1, 2, 3), "int64"))
pa.register_extension_type(Array4DExtensionType((1, 2, 3, 4), "int64"))
pa.register_extension_type(Array5DExtensionType((1, 2, 3, 4, 5), "int64"))


def _is_zero_copy_only(pa_type: pa.DataType, unnest: bool = False) -> bool:
"""
When converting a pyarrow array to a numpy array, we must know whether this could be done in zero-copy or not.
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,7 +1891,7 @@ def _offsets_concat(offsets):

def _concat_arrays(arrays):
array_type = arrays[0].type
if isinstance(array_type, pa.PyExtensionType):
if isinstance(array_type, pa.ExtensionType):
return array_type.wrap_array(_concat_arrays([array.storage for array in arrays]))
elif pa.types.is_struct(array_type):
return pa.StructArray.from_arrays(
Expand Down