Skip to content

Commit c096bd2

Browse files
Support pyarrow 14.0.1 and fix vulnerability CVE-2023-47248 (#6404)
* Replace pa.PyExtensionType with pa.ExtensionType
* Register user-defined extension types
* Pin minimum pyarrow version to 14.0.1
* Temporarily pin minimum pyarrow due to beam constraint
* Remove constraint on pyarrow by removing unneeded upper beam version
* Reset pyarrow minimum due to apache-beam constraint
* Revert last 2 commits
* Revert minimum pyarrow version and use pyarrow-hotfix
* Add pa.ExtensionType.__reduce__
1 parent 30caa09 commit c096bd2

File tree

4 files changed

+24
-8
lines changed

4 files changed

+24
-8
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ jobs:
6262
- name: Install dependencies (latest versions)
6363
if: ${{ matrix.deps_versions == 'deps-latest' }}
6464
run: pip install --upgrade pyarrow huggingface-hub dill
65-
- name: Install depencencies (minimum versions)
65+
- name: Install dependencies (minimum versions)
6666
if: ${{ matrix.deps_versions != 'deps-latest' }}
6767
run: pip install pyarrow==8.0.0 huggingface-hub==0.18.0 transformers dill==0.3.1.1
6868
- name: Test with pytest

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@
113113
# Backend and serialization.
114114
# Minimum 8.0.0 to be able to use .to_reader()
115115
"pyarrow>=8.0.0",
116+
# As long as we allow pyarrow < 14.0.1, to fix vulnerability CVE-2023-47248
117+
"pyarrow-hotfix",
116118
# For smart caching dataset processing
117119
"dill>=0.3.0,<0.3.8", # tmp pin until dill has official support for determinism see https://github.com/uqfoundation/dill/issues/19
118120
# For performance gains with apache arrow

src/datasets/features/features.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import pyarrow as pa
3232
import pyarrow.compute as pc
3333
import pyarrow.types
34+
import pyarrow_hotfix # noqa: F401 # to fix vulnerability on pyarrow<14.0.1
3435
from pandas.api.extensions import ExtensionArray as PandasExtensionArray
3536
from pandas.api.extensions import ExtensionDtype as PandasExtensionDtype
3637

@@ -631,7 +632,7 @@ class Array5D(_ArrayXD):
631632
_type: str = field(default="Array5D", init=False, repr=False)
632633

633634

634-
class _ArrayXDExtensionType(pa.PyExtensionType):
635+
class _ArrayXDExtensionType(pa.ExtensionType):
635636
ndims: Optional[int] = None
636637

637638
def __init__(self, shape: tuple, dtype: str):
@@ -645,13 +646,19 @@ def __init__(self, shape: tuple, dtype: str):
645646
self.shape = tuple(shape)
646647
self.value_type = dtype
647648
self.storage_dtype = self._generate_dtype(self.value_type)
648-
pa.PyExtensionType.__init__(self, self.storage_dtype)
649+
pa.ExtensionType.__init__(self, self.storage_dtype, f"{self.__class__.__module__}.{self.__class__.__name__}")
649650

651+
def __arrow_ext_serialize__(self):
652+
return json.dumps((self.shape, self.value_type)).encode()
653+
654+
@classmethod
655+
def __arrow_ext_deserialize__(cls, storage_type, serialized):
656+
args = json.loads(serialized)
657+
return cls(*args)
658+
659+
# This was added to pa.ExtensionType in pyarrow >= 13.0.0
650660
def __reduce__(self):
651-
return self.__class__, (
652-
self.shape,
653-
self.value_type,
654-
)
661+
return self.__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())
655662

656663
def __hash__(self):
657664
return hash((self.__class__, self.shape, self.value_type))
@@ -687,6 +694,13 @@ class Array5DExtensionType(_ArrayXDExtensionType):
687694
ndims = 5
688695

689696

697+
# Register the extension types for deserialization
698+
pa.register_extension_type(Array2DExtensionType((1, 2), "int64"))
699+
pa.register_extension_type(Array3DExtensionType((1, 2, 3), "int64"))
700+
pa.register_extension_type(Array4DExtensionType((1, 2, 3, 4), "int64"))
701+
pa.register_extension_type(Array5DExtensionType((1, 2, 3, 4, 5), "int64"))
702+
703+
690704
def _is_zero_copy_only(pa_type: pa.DataType, unnest: bool = False) -> bool:
691705
"""
692706
When converting a pyarrow array to a numpy array, we must know whether this could be done in zero-copy or not.

src/datasets/table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1891,7 +1891,7 @@ def _offsets_concat(offsets):
18911891

18921892
def _concat_arrays(arrays):
18931893
array_type = arrays[0].type
1894-
if isinstance(array_type, pa.PyExtensionType):
1894+
if isinstance(array_type, pa.ExtensionType):
18951895
return array_type.wrap_array(_concat_arrays([array.storage for array in arrays]))
18961896
elif pa.types.is_struct(array_type):
18971897
return pa.StructArray.from_arrays(

0 commit comments

Comments (0)