Skip to content

Commit c3c45db

Browse files
authored
Merge pull request #1 from YQ-Wang/1.15.1-fix
Merge in fix from huggingface#6404
2 parents: 0181006 + 976f3e4; commit: c3c45db

File tree

2 files changed: +22 −6 lines changed

2 files changed

+22
-6
lines changed

setup.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,8 @@
7777
# Minimum 3.0.0 to support mix of struct and list types in parquet, and batch iterators of parquet data
7878
# pyarrow 4.0.0 introduced segfault bug, see: https://github.com/huggingface/datasets/pull/2268
7979
"pyarrow>=1.0.0,!=4.0.0",
80+
# As long as we allow pyarrow < 14.0.1, to fix vulnerability CVE-2023-47248
81+
"pyarrow-hotfix",
8082
# For smart caching dataset processing
8183
"dill",
8284
# For performance gains with apache arrow

src/datasets/features/features.py

Lines changed: 20 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
# Lint as: python3
1717
""" This class handle features definition in datasets and some utilities to display table type."""
1818
import copy
19+
import json
1920
import re
2021
import sys
2122
from collections.abc import Iterable
@@ -30,6 +31,7 @@
3031
import pandas as pd
3132
import pyarrow as pa
3233
import pyarrow.types
34+
import pyarrow_hotfix # noqa: F401 # to fix vulnerability on pyarrow<14.0.1
3335
from pandas.api.extensions import ExtensionArray as PandasExtensionArray
3436
from pandas.api.extensions import ExtensionDtype as PandasExtensionDtype
3537
from pyarrow.lib import TimestampType
@@ -353,7 +355,7 @@ class Array5D(_ArrayXD):
353355
_type: str = field(default="Array5D", init=False, repr=False)
354356

355357

356-
class _ArrayXDExtensionType(pa.PyExtensionType):
358+
class _ArrayXDExtensionType(pa.ExtensionType):
357359
ndims: Optional[int] = None
358360

359361
def __init__(self, shape: tuple, dtype: str):
@@ -364,13 +366,18 @@ def __init__(self, shape: tuple, dtype: str):
364366
self.shape = tuple(shape)
365367
self.value_type = dtype
366368
self.storage_dtype = self._generate_dtype(self.value_type)
367-
pa.PyExtensionType.__init__(self, self.storage_dtype)
369+
pa.ExtensionType.__init__(self, self.storage_dtype, f"{self.__class__.__module__}.{self.__class__.__name__}")
370+
371+
def __arrow_ext_serialize__(self):
372+
return json.dumps((self.shape, self.value_type)).encode()
373+
374+
@classmethod
375+
def __arrow_ext_deserialize__(cls, storage_type, serialized):
376+
args = json.loads(serialized)
377+
return cls(*args)
368378

369379
def __reduce__(self):
370-
return self.__class__, (
371-
self.shape,
372-
self.value_type,
373-
)
380+
return self.__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())
374381

375382
def __arrow_ext_class__(self):
376383
return ArrayExtensionArray
@@ -403,6 +410,13 @@ class Array5DExtensionType(_ArrayXDExtensionType):
403410
ndims = 5
404411

405412

413+
# Register the extension types for deserialization
414+
pa.register_extension_type(Array2DExtensionType((1, 2), "int64"))
415+
pa.register_extension_type(Array3DExtensionType((1, 2, 3), "int64"))
416+
pa.register_extension_type(Array4DExtensionType((1, 2, 3, 4), "int64"))
417+
pa.register_extension_type(Array5DExtensionType((1, 2, 3, 4, 5), "int64"))
418+
419+
406420
def _is_zero_copy_only(pa_type: pa.DataType) -> bool:
407421
"""
408422
When converting a pyarrow array to a numpy array, we must know whether this could be done in zero-copy or not.

0 commit comments

Comments
 (0)