
Commit ab6946d

mariosasko and lhoestq authored
Fix embed_storage on features inside lists/sequences (#4615)
* Dedicated function for embedding data into table
* Add test
* minor
* minor 2
* add test

Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 77cce3c commit ab6946d
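For context, the user-facing scenario this commit fixes, as a minimal sketch: a dataset whose column type is a list of images can now be pushed with the file content embedded. The repo name and image paths below are placeholders.

from datasets import Dataset, Features, Image, Value

# A column typed as a *list* of images -- embed_storage previously was not
# applied to features nested like this. "img1.jpg"/"img2.jpg" are placeholders.
data = {"x": [["img1.jpg"], ["img1.jpg", "img2.jpg"]], "y": [0, -1]}
features = Features({"x": [Image()], "y": Value("int32")})
ds = Dataset.from_dict(data, features=features)

# With embed_external_files=True (the default), the image bytes are embedded
# into the uploaded Arrow shards instead of being referenced by local path.
ds.push_to_hub("username/my-image-dataset", embed_external_files=True)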

File tree: 5 files changed, +191 -26 lines

src/datasets/arrow_dataset.py
src/datasets/features/features.py
src/datasets/table.py
tests/test_table.py
tests/test_upstream_hub.py


src/datasets/arrow_dataset.py

Lines changed: 16 additions & 23 deletions
@@ -84,8 +84,8 @@
     InMemoryTable,
     MemoryMappedTable,
     Table,
-    cast_table_to_features,
     concat_tables,
+    embed_table_storage,
     list_table_cache_files,
     table_cast,
     table_visitor,
@@ -95,7 +95,7 @@
 from .utils._hf_hub_fixes import create_repo
 from .utils.file_utils import _retry, cached_path, estimate_dataset_size, hf_hub_url
 from .utils.info_utils import is_small_dataset
-from .utils.py_utils import convert_file_size_to_int, temporary_assignment, unique_values
+from .utils.py_utils import convert_file_size_to_int, unique_values
 from .utils.stratify import stratified_shuffle_split_generate_indices
 from .utils.tf_utils import minimal_tf_collate_fn
 from .utils.typing import PathLike
@@ -4150,26 +4150,17 @@ def extra_nbytes_visitor(array, feature):
         if decodable_columns:
 
             def shards_with_embedded_external_files(shards):
-                # Temporarily assign the modified version of `cast_storage` before the cast to the decodable
-                # feature types to delete path information and embed file content in the arrow file.
-                with contextlib.ExitStack() as stack:
-                    for decodable_feature_type in [Audio, Image]:
-                        stack.enter_context(
-                            temporary_assignment(
-                                decodable_feature_type, "cast_storage", decodable_feature_type.embed_storage
-                            )
-                        )
-                    for shard in shards:
-                        format = shard.format
-                        shard = shard.with_format("arrow")
-                        shard = shard.map(
-                            partial(cast_table_to_features, features=shard.features),
-                            batched=True,
-                            batch_size=1000,
-                            keep_in_memory=True,
-                        )
-                        shard = shard.with_format(**format)
-                        yield shard
+                for shard in shards:
+                    format = shard.format
+                    shard = shard.with_format("arrow")
+                    shard = shard.map(
+                        embed_table_storage,
+                        batched=True,
+                        batch_size=1000,
+                        keep_in_memory=True,
+                    )
+                    shard = shard.with_format(**format)
+                    yield shard
 
             shards = shards_with_embedded_external_files(shards)
 
@@ -4224,7 +4215,9 @@ def path_in_repo(_index, shard):
                 for data_file in data_files
                 if data_file.startswith(f"data/{split}-") and data_file not in shards_path_in_repo
             ]
-            deleted_size = sum(xgetsize(hf_hub_url(repo_id, data_file)) for data_file in data_files_to_delete)
+            deleted_size = sum(
+                xgetsize(hf_hub_url(repo_id, data_file), use_auth_token=token) for data_file in data_files_to_delete
+            )
 
             def delete_file(file):
                 api.delete_file(file, repo_id=repo_id, token=token, repo_type="dataset", revision=branch)
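The refactor above replaces the temporary monkey-patching of cast_storage with a single map over Arrow-formatted shards. The same operation can be applied to any dataset directly; a minimal sketch, assuming "img1.jpg" is a real image file (the path is a placeholder):

from datasets import Dataset, Features, Image
from datasets.table import embed_table_storage

ds = Dataset.from_dict({"image": ["img1.jpg"]}, features=Features({"image": Image()}))

# In "arrow" format, map feeds pyarrow.Table batches to the function, and
# embed_table_storage returns tables with file bytes embedded in place of paths.
ds = ds.with_format("arrow")
ds = ds.map(embed_table_storage, batched=True, batch_size=1000, keep_in_memory=True)
ds = ds.with_format(None)  # restore the default python format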

src/datasets/features/features.py

Lines changed: 18 additions & 0 deletions
@@ -1438,6 +1438,24 @@ def require_storage_cast(feature: FeatureType) -> bool:
     return hasattr(feature, "cast_storage")
 
 
+def require_storage_embed(feature: FeatureType) -> bool:
+    """Check if a (possibly nested) feature requires embedding data into storage.
+
+    Args:
+        feature (FeatureType): the feature type to be checked
+    Returns:
+        :obj:`bool`
+    """
+    if isinstance(feature, dict):
+        return any(require_storage_cast(f) for f in feature.values())
+    elif isinstance(feature, (list, tuple)):
+        return require_storage_cast(feature[0])
+    elif isinstance(feature, Sequence):
+        return require_storage_cast(feature.feature)
+    else:
+        return hasattr(feature, "embed_storage")
+
+
 def keep_features_dicts_synced(func):
     """
     Wrapper to keep the secondary dictionary, which tracks whether keys are decodable, of the :class:`datasets.Features` object
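Note that the nested branches delegate to require_storage_cast rather than recursing on require_storage_embed; this is equivalent for features such as Image and Audio, which define both cast_storage and embed_storage. A quick sketch of what the new helper reports:

from datasets import Sequence, Value
from datasets.features.features import Image, require_storage_embed

assert require_storage_embed(Image())             # Image defines embed_storage
assert require_storage_embed([Image()])           # image nested in a plain list
assert require_storage_embed(Sequence(Image()))   # image nested in a Sequence
assert not require_storage_embed(Value("int32"))  # no embed_storage to apply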

src/datasets/table.py

Lines changed: 102 additions & 2 deletions
@@ -1779,7 +1779,7 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
         = if casting from numbers to strings and allow_number_to_str is False
 
     Returns:
-        array (:obj:`pyarrow.Array`): the casted array
+        array (:obj:`pyarrow.Array`): the casted array
     """
     from .features.features import Sequence, get_nested_type
 
@@ -1850,8 +1850,89 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
     raise TypeError(f"Couldn't cast array of type\n{array.type}\nto\n{feature}")
 
 
+@_wrap_for_chunked_arrays
+def embed_array_storage(array: pa.Array, feature: "FeatureType"):
+    """Embed data into an arrays's storage.
+    For custom features like Audio or Image, it takes into account the "embed_storage" methods
+    they defined to enable embedding external data (e.g. an image file) into an other arrow types.
+
+    Args:
+        array (pa.Array): the PyArrow array in which to embed data
+        feature (FeatureType): array features
+
+    Raises:
+        TypeError: if the target type is not supported according, e.g.
+
+            - if a field is missing
+
+    Returns:
+        array (:obj:`pyarrow.Array`): the casted array
+    """
+    from .features import Sequence
+
+    _e = embed_array_storage
+
+    if isinstance(array, pa.ExtensionArray):
+        array = array.storage
+    if hasattr(feature, "embed_storage"):
+        return feature.embed_storage(array)
+    elif pa.types.is_struct(array.type):
+        # feature must be a dict or Sequence(subfeatures_dict)
+        if isinstance(feature, Sequence) and isinstance(feature.feature, dict):
+            feature = {
+                name: Sequence(subfeature, length=feature.length) for name, subfeature in feature.feature.items()
+            }
+        if isinstance(feature, dict):
+            arrays = [_e(array.field(name), subfeature) for name, subfeature in feature.items()]
+            return pa.StructArray.from_arrays(arrays, names=list(feature), mask=array.is_null())
+    elif pa.types.is_list(array.type):
+        # feature must be either [subfeature] or Sequence(subfeature)
+        if isinstance(feature, list):
+            if array.null_count > 0:
+                warnings.warn(
+                    f"None values are converted to empty lists when embedding array storage with {feature}. More info: https://github.com/huggingface/datasets/issues/3676. This will raise an error in a future major version of `datasets`"
+                )
+            return pa.ListArray.from_arrays(array.offsets, _e(array.values, feature[0]))
+        elif isinstance(feature, Sequence):
+            if feature.length > -1:
+                if feature.length * len(array) == len(array.values):
+                    return pa.FixedSizeListArray.from_arrays(_e(array.values, feature.feature), feature.length)
+            else:
+                casted_values = _e(array.values, feature.feature)
+                if casted_values.type == array.values.type:
+                    return array
+                else:
+                    if array.null_count > 0:
+                        warnings.warn(
+                            f"None values are converted to empty lists when embedding array storage with {feature}. More info: https://github.com/huggingface/datasets/issues/3676. This will raise an error in a future major version of `datasets`"
+                        )
+                    return pa.ListArray.from_arrays(array.offsets, _e(array.values, feature.feature))
+    elif pa.types.is_fixed_size_list(array.type):
+        # feature must be either [subfeature] or Sequence(subfeature)
+        if isinstance(feature, list):
+            if array.null_count > 0:
+                warnings.warn(
+                    f"None values are converted to empty lists when embedding array storage with {feature}. More info: https://github.com/huggingface/datasets/issues/3676. This will raise an error in a future major version of `datasets`"
+                )
+            return pa.ListArray.from_arrays(array.offsets, _e(array.values, feature[0]))
+        elif isinstance(feature, Sequence):
+            if feature.length > -1:
+                if feature.length * len(array) == len(array.values):
+                    return pa.FixedSizeListArray.from_arrays(_e(array.values, feature.feature), feature.length)
+            else:
+                offsets_arr = pa.array(range(len(array) + 1), pa.int32())
+                if array.null_count > 0:
+                    warnings.warn(
+                        f"None values are converted to empty lists when embedding array storage with {feature}. More info: https://github.com/huggingface/datasets/issues/3676. This will raise an error in a future major version of `datasets`"
+                    )
+                return pa.ListArray.from_arrays(offsets_arr, _e(array.values, feature.feature))
+    if not isinstance(feature, (Sequence, dict, list, tuple)):
+        return array
+    raise TypeError(f"Couldn't embed array of type\n{array.type}\nwith\n{feature}")
+
+
 def cast_table_to_features(table: pa.Table, features: "Features"):
-    """Cast an table to the arrow schema that corresponds to the requested features.
+    """Cast a table to the arrow schema that corresponds to the requested features.
 
     Args:
         table (:obj:`pyarrow.Table`): PyArrow table to cast
@@ -1885,6 +1966,25 @@ def cast_table_to_schema(table: pa.Table, schema: pa.Schema):
     return pa.Table.from_arrays(arrays, schema=schema)
 
 
+def embed_table_storage(table: pa.Table):
+    """Embed external data into a table's storage.
+
+    Args:
+        table (:obj:`pyarrow.Table`): PyArrow table in which to embed data
+
+    Returns:
+        table (:obj:`pyarrow.Table`): the table with embedded data
+    """
+    from .features.features import Features, require_storage_embed
+
+    features = Features.from_arrow_schema(table.schema)
+    arrays = [
+        embed_array_storage(table[name], feature) if require_storage_embed(feature) else table[name]
+        for name, feature in features.items()
+    ]
+    return pa.Table.from_arrays(arrays, schema=features.arrow_schema)
+
+
 def table_cast(table: pa.Table, schema: pa.Schema):
     """Improved version of pa.Table.cast.
 
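Together these helpers make embedding work directly on PyArrow data, including the nested list case this commit fixes. A minimal sketch, mirroring the new tests; "img.jpg" is a placeholder for an image file that exists on disk:

import pyarrow as pa
from datasets.features.features import Image
from datasets.table import embed_array_storage

# A list-of-images array: the nested case that previously had no embedding path.
arr = pa.array([[{"bytes": None, "path": "img.jpg"}]], type=pa.list_(Image.pa_type))
embedded = embed_array_storage(arr, [Image()])

first = embedded.to_pylist()[0][0]
assert first["path"] is None and isinstance(first["bytes"], bytes)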

tests/test_table.py

Lines changed: 30 additions & 1 deletion
@@ -7,7 +7,7 @@
 import pytest
 
 from datasets import Sequence, Value
-from datasets.features.features import ClassLabel, Features
+from datasets.features.features import ClassLabel, Features, Image
 from datasets.table import (
     ConcatenationTable,
     InMemoryTable,
@@ -20,7 +20,10 @@
     _memory_mapped_arrow_table_from_file,
     cast_array_to_feature,
     concat_tables,
+    embed_array_storage,
+    embed_table_storage,
     inject_arrow_table_documentation,
+    table_cast,
 )
 
 from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, slow
@@ -1045,3 +1048,29 @@ def test_cast_array_to_features_to_null_type():
     arr = pa.array([[None, 1]])
     with pytest.raises(TypeError):
         cast_array_to_feature(arr, Sequence(Value("null")))
+
+
+def test_embed_array_storage(image_file):
+    array = pa.array([{"bytes": None, "path": image_file}], type=Image.pa_type)
+    embedded_images_array = embed_array_storage(array, Image())
+    assert embedded_images_array.to_pylist()[0]["path"] is None
+    assert isinstance(embedded_images_array.to_pylist()[0]["bytes"], bytes)
+
+
+def test_embed_array_storage_nested(image_file):
+    array = pa.array([[{"bytes": None, "path": image_file}]], type=pa.list_(Image.pa_type))
+    embedded_images_array = embed_array_storage(array, [Image()])
+    assert embedded_images_array.to_pylist()[0][0]["path"] is None
+    assert isinstance(embedded_images_array.to_pylist()[0][0]["bytes"], bytes)
+    array = pa.array([{"foo": {"bytes": None, "path": image_file}}], type=pa.struct({"foo": Image.pa_type}))
+    embedded_images_array = embed_array_storage(array, {"foo": Image()})
+    assert embedded_images_array.to_pylist()[0]["foo"]["path"] is None
+    assert isinstance(embedded_images_array.to_pylist()[0]["foo"]["bytes"], bytes)
+
+
+def test_embed_table_storage(image_file):
+    features = Features({"image": Image()})
+    table = table_cast(pa.table({"image": [image_file]}), features.arrow_schema)
+    embedded_images_table = embed_table_storage(table)
+    assert embedded_images_table.to_pydict()["image"][0]["path"] is None
+    assert isinstance(embedded_images_table.to_pydict()["image"][0]["bytes"], bytes)
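The image_file argument in these tests is a pytest fixture (presumably defined in the repo's test conftest) that resolves to a real image on disk. The same table-level flow works outside pytest; a sketch with a placeholder path:

import pyarrow as pa
from datasets import Features
from datasets.features.features import Image
from datasets.table import embed_table_storage, table_cast

features = Features({"image": Image()})
# table_cast turns plain path strings into the Image storage type
# ({"bytes": None, "path": ...}); "img.jpg" is a placeholder.
table = table_cast(pa.table({"image": ["img.jpg"]}), features.arrow_schema)
embedded = embed_table_storage(table)
assert isinstance(embedded.to_pydict()["image"][0]["bytes"], bytes)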

tests/test_upstream_hub.py

Lines changed: 25 additions & 0 deletions
@@ -417,6 +417,31 @@ def test_push_dataset_to_hub_custom_features_image(self):
         finally:
             self.cleanup_repo(ds_name)
 
+    @require_pil
+    def test_push_dataset_to_hub_custom_features_image_list(self):
+        image_path = os.path.join(os.path.dirname(__file__), "features", "data", "test_image_rgb.jpg")
+        data = {"x": [[image_path], [image_path, image_path]], "y": [0, -1]}
+        features = Features({"x": [Image()], "y": Value("int32")})
+        ds = Dataset.from_dict(data, features=features)
+
+        for embed_external_files in [True, False]:
+            ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
+            try:
+                ds.push_to_hub(ds_name, embed_external_files=embed_external_files, token=self._token)
+                hub_ds = load_dataset(ds_name, split="train", download_mode="force_redownload")
+
+                self.assertListEqual(ds.column_names, hub_ds.column_names)
+                self.assertListEqual(list(ds.features.keys()), list(hub_ds.features.keys()))
+                self.assertDictEqual(ds.features, hub_ds.features)
+                self.assertEqual(ds[:], hub_ds[:])
+                hub_ds = hub_ds.cast_column("x", [Image(decode=False)])
+                elem = hub_ds[0]["x"][0]
+                path, bytes_ = elem["path"], elem["bytes"]
+                self.assertTrue(bool(path) == (not embed_external_files))
+                self.assertTrue(bool(bytes_) == embed_external_files)
+            finally:
+                self.cleanup_repo(ds_name)
+
     def test_push_dataset_dict_to_hub_custom_features(self):
         features = Features({"x": Value("int64"), "y": ClassLabel(names=["neg", "pos"])})
         ds = Dataset.from_dict({"x": [1, 2, 3], "y": [0, 0, 1]}, features=features)
