Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@
"pyspark>=3.4", # https://issues.apache.org/jira/browse/SPARK-40991 fixed in 3.4.0
"py7zr",
"rarfile>=4.0",
"sqlalchemy<2.0.0",
"sqlalchemy",
"s3fs>=2021.11.1", # aligned with fsspec[http]>=2021.11.1; test only on python 3.7 for now
"tensorflow>=2.3,!=2.6.0,!=2.6.1; sys_platform != 'darwin' or platform_machine != 'arm64'",
"tensorflow-macos; sys_platform == 'darwin' and platform_machine == 'arm64'",
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas
key: _cast_to_python_objects(
value, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
)[0]
for key, value in obj.to_dict("list").items()
for key, value in obj.to_dict("series").items()
},
True,
)
Expand Down
5 changes: 2 additions & 3 deletions src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,9 +2076,6 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
elif pa.types.is_fixed_size_list(array.type):
# feature must be either [subfeature] or Sequence(subfeature)
array_values = array.values
if config.PYARROW_VERSION.major < 15:
# PyArrow bug: https://github.com/apache/arrow/issues/35360
array_values = array.values[array.offset * array.type.list_size :]
if isinstance(feature, list):
if array.null_count > 0:
if config.PYARROW_VERSION.major < 10:
Expand All @@ -2090,6 +2087,8 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
return pa.ListArray.from_arrays(array.offsets, _c(array_values, feature[0]))
elif isinstance(feature, Sequence):
if feature.length > -1:
if array.offset and feature.length * len(array) != len(array_values):
array_values = array.values[array.offset * array.type.list_size :]
if feature.length * len(array) == len(array_values):
return pa.FixedSizeListArray.from_arrays(_c(array_values, feature.feature), feature.length)
else:
Expand Down
5 changes: 4 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ def set_update_download_counts_to_false(monkeypatch):
def set_sqlalchemy_silence_uber_warning(monkeypatch):
# Required to suppress RemovedIn20Warning when feature(s) are not compatible with SQLAlchemy 2.0
# To be removed once SQLAlchemy 2.0 supported
monkeypatch.setattr("sqlalchemy.util.deprecations.SILENCE_UBER_WARNING", True)
try:
monkeypatch.setattr("sqlalchemy.util.deprecations.SILENCE_UBER_WARNING", True)
except AttributeError:
pass


@pytest.fixture(autouse=True, scope="session")
Expand Down