Skip to content

Commit 129b9e0

Browse files
lhoestqmariosasko
andauthored
Fix CI: pyarrow 15, pandas 2.2 and sqlachemy (#6617)
* fix cast sliced fixed size list for pyarrow 15 * fix pandas timedelta * unpin qslalchemy * Update src/datasets/table.py Co-authored-by: Mario Šaško <[email protected]> * style --------- Co-authored-by: Mario Šaško <[email protected]>
1 parent 3267234 commit 129b9e0

File tree

4 files changed

+10
-6
lines changed

4 files changed

+10
-6
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@
175175
"pyspark>=3.4", # https://issues.apache.org/jira/browse/SPARK-40991 fixed in 3.4.0
176176
"py7zr",
177177
"rarfile>=4.0",
178-
"sqlalchemy<2.0.0",
178+
"sqlalchemy",
179179
"s3fs>=2021.11.1", # aligned with fsspec[http]>=2021.11.1; test only on python 3.7 for now
180180
"tensorflow>=2.3,!=2.6.0,!=2.6.1; sys_platform != 'darwin' or platform_machine != 'arm64'",
181181
"tensorflow-macos; sys_platform == 'darwin' and platform_machine == 'arm64'",

src/datasets/features/features.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas
370370
key: _cast_to_python_objects(
371371
value, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
372372
)[0]
373-
for key, value in obj.to_dict("list").items()
373+
for key, value in obj.to_dict("series").items()
374374
},
375375
True,
376376
)

src/datasets/table.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2076,9 +2076,6 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
20762076
elif pa.types.is_fixed_size_list(array.type):
20772077
# feature must be either [subfeature] or Sequence(subfeature)
20782078
array_values = array.values
2079-
if config.PYARROW_VERSION.major < 15:
2080-
# PyArrow bug: https://github.com/apache/arrow/issues/35360
2081-
array_values = array.values[array.offset * array.type.list_size :]
20822079
if isinstance(feature, list):
20832080
if array.null_count > 0:
20842081
if config.PYARROW_VERSION.major < 10:
@@ -2090,6 +2087,10 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
20902087
return pa.ListArray.from_arrays(array.offsets, _c(array_values, feature[0]))
20912088
elif isinstance(feature, Sequence):
20922089
if feature.length > -1:
2090+
if array.offset and feature.length * len(array) != len(array_values):
2091+
array_values = array.values[
2092+
array.offset * array.type.list_size : (array.offset + len(array)) * array.type.list_size
2093+
]
20932094
if feature.length * len(array) == len(array_values):
20942095
return pa.FixedSizeListArray.from_arrays(_c(array_values, feature.feature), feature.length)
20952096
else:

tests/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ def set_update_download_counts_to_false(monkeypatch):
5151
def set_sqlalchemy_silence_uber_warning(monkeypatch):
5252
# Required to suppress RemovedIn20Warning when feature(s) are not compatible with SQLAlchemy 2.0
5353
# To be removed once SQLAlchemy 2.0 supported
54-
monkeypatch.setattr("sqlalchemy.util.deprecations.SILENCE_UBER_WARNING", True)
54+
try:
55+
monkeypatch.setattr("sqlalchemy.util.deprecations.SILENCE_UBER_WARNING", True)
56+
except AttributeError:
57+
pass
5558

5659

5760
@pytest.fixture(autouse=True, scope="session")

0 commit comments

Comments
 (0)