17 changes: 9 additions & 8 deletions src/datasets/features/features.py
@@ -1240,10 +1240,7 @@ def encode_nested_example(schema, obj, level=0):
         if level == 0 and obj is None:
             raise ValueError("Got None but expected a dictionary instead")
         return (
-            {
-                k: encode_nested_example(sub_schema, sub_obj, level=level + 1)
-                for k, (sub_schema, sub_obj) in zip_dict(schema, obj)
-            }
+            {k: encode_nested_example(schema[k], obj.get(k), level=level + 1) for k in schema}
             if obj is not None
             else None
         )
@@ -1269,13 +1266,17 @@ def encode_nested_example(schema, obj, level=0):
             list_dict = {}
             if isinstance(obj, (list, tuple)):
                 # obj is a list of dict
-                for k, dict_tuples in zip_dict(schema.feature, *obj):
-                    list_dict[k] = [encode_nested_example(dict_tuples[0], o, level=level + 1) for o in dict_tuples[1:]]
+                for k in schema.feature:
+                    list_dict[k] = [encode_nested_example(schema.feature[k], o.get(k), level=level + 1) for o in obj]
                 return list_dict
             else:
                 # obj is a single dict
-                for k, (sub_schema, sub_objs) in zip_dict(schema.feature, obj):
-                    list_dict[k] = [encode_nested_example(sub_schema, o, level=level + 1) for o in sub_objs]
+                for k in schema.feature:
+                    list_dict[k] = (
+                        [encode_nested_example(schema.feature[k], o, level=level + 1) for o in obj[k]]
+                        if k in obj
+                        else None
+                    )
                 return list_dict
         # schema.feature is not a dict
         if isinstance(obj, str):  # don't interpret a string as a list
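Note: the net effect of this change is that schema keys missing from an example are encoded as `None` instead of failing inside `zip_dict` (which indexes every key into every dict). A minimal sketch, not part of the PR, using the public `Features.encode_example` API and mirroring the new `test_webdataset_with_features` test below:

```python
from datasets import Features, Value

features = Features({"json": {"caption": Value("string"), "additional_field": Value("int64")}})
# "additional_field" is declared in the schema but absent from the example;
# with the old zip_dict-based code this raised a KeyError:
encoded = features.encode_example({"json": {"caption": "hello"}})
assert encoded["json"]["additional_field"] is None
```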
95 changes: 76 additions & 19 deletions src/datasets/packaged_modules/webdataset/webdataset.py
@@ -15,6 +15,7 @@
 class WebDataset(datasets.GeneratorBasedBuilder):
     DEFAULT_WRITER_BATCH_SIZE = 100
     IMAGE_EXTENSIONS: List[str]  # definition at the bottom of the script
+    AUDIO_EXTENSIONS: List[str]  # definition at the bottom of the script
     DECODERS: Dict[str, Callable[[Any], Any]]  # definition at the bottom of the script
     NUM_EXAMPLES_FOR_FEATURES_INFERENCE = 5
 
@@ -65,34 +66,46 @@ def _split_generators(self, dl_manager):
                     name=split_name, gen_kwargs={"tar_paths": tar_paths, "tar_iterators": tar_iterators}
                 )
             )
-
-        # Get one example to get the feature types
-        pipeline = self._get_pipeline_from_tar(tar_paths[0], tar_iterators[0])
-        first_examples = list(islice(pipeline, self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE))
-        if any(example.keys() != first_examples[0].keys() for example in first_examples):
-            raise ValueError(
-                "The TAR archives of the dataset should be in WebDataset format, "
-                "but the files in the archive don't share the same prefix or the same types."
-            )
-        inferred_arrow_schema = pa.Table.from_pylist(first_examples[:1]).schema
-        features = datasets.Features.from_arrow_schema(inferred_arrow_schema)
-
-        # Set Image types
-        for field_name in first_examples[0]:
-            extension = field_name.rsplit(".", 1)[-1]
-            if extension in self.IMAGE_EXTENSIONS:
-                features[field_name] = datasets.Image()
-        self.info.features = features
+        if not self.info.features:
+            # Get one example to get the feature types
+            pipeline = self._get_pipeline_from_tar(tar_paths[0], tar_iterators[0])
+            first_examples = list(islice(pipeline, self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE))
+            if any(example.keys() != first_examples[0].keys() for example in first_examples):
+                raise ValueError(
+                    "The TAR archives of the dataset should be in WebDataset format, "
+                    "but the files in the archive don't share the same prefix or the same types."
+                )
+            pa_tables = [pa.Table.from_pylist([example]) for example in first_examples]
+            if datasets.config.PYARROW_VERSION.major < 14:
+                inferred_arrow_schema = pa.concat_tables(pa_tables, promote=True).schema
+            else:
+                inferred_arrow_schema = pa.concat_tables(pa_tables, promote_options="default").schema
+            features = datasets.Features.from_arrow_schema(inferred_arrow_schema)
+
+            # Set Image types
+            for field_name in first_examples[0]:
+                extension = field_name.rsplit(".", 1)[-1]
+                if extension in self.IMAGE_EXTENSIONS:
+                    features[field_name] = datasets.Image()
+            # Set Audio types
+            for field_name in first_examples[0]:
+                extension = field_name.rsplit(".", 1)[-1]
+                if extension in self.AUDIO_EXTENSIONS:
+                    features[field_name] = datasets.Audio()
+            self.info.features = features
 
         return splits
 
     def _generate_examples(self, tar_paths, tar_iterators):
         image_field_names = [
             field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Image)
         ]
+        audio_field_names = [
+            field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Audio)
+        ]
         for tar_idx, (tar_path, tar_iterator) in enumerate(zip(tar_paths, tar_iterators)):
             for example_idx, example in enumerate(self._get_pipeline_from_tar(tar_path, tar_iterator)):
-                for field_name in image_field_names:
+                for field_name in image_field_names + audio_field_names:
                     example[field_name] = {"path": example["__key__"] + "." + field_name, "bytes": example[field_name]}
                 yield f"{tar_idx}_{example_idx}", example
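
Note: schema inference now builds one single-row Arrow table per sampled example and unifies them with type promotion, instead of trusting only the first example; a field that happens to be null in the first example no longer pins the column type. A minimal sketch of the promotion behavior, independent of `datasets` (`promote_options="default"` is the PyArrow >= 14 spelling; older versions use `promote=True`):

```python
import pyarrow as pa

examples = [{"txt": "a", "json": None}, {"txt": "b", "json": "{}"}]
pa_tables = [pa.Table.from_pylist([example]) for example in examples]
# "default" promotion unifies a null-typed column with the concrete type
# seen in later examples (here: null + string -> string):
schema = pa.concat_tables(pa_tables, promote_options="default").schema
print(schema)  # txt: string, json: string
```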

@@ -177,6 +190,50 @@ def _generate_examples(self, tar_paths, tar_iterators):
 WebDataset.IMAGE_EXTENSIONS = IMAGE_EXTENSIONS
 
 
+# Obtained with:
+# ```
+# import soundfile as sf
+#
+# AUDIO_EXTENSIONS = [format.lower() for format in sf.available_formats()]
+#
+# # .mp3 is currently decoded via `torchaudio`; .opus decoding is supported if the `libsndfile` version is >= 1.0.30:
+# AUDIO_EXTENSIONS.extend(["mp3", "opus"])
+# ```
+# We intentionally do not run this code on launch because:
+# (1) Soundfile is an optional dependency, so importing it in the global namespace is not allowed
+# (2) To ensure the list of supported extensions is deterministic
+AUDIO_EXTENSIONS = [
+    "aiff",
+    "au",
+    "avr",
+    "caf",
+    "flac",
+    "htk",
+    "svx",
+    "mat4",
+    "mat5",
+    "mpc2k",
+    "ogg",
+    "paf",
+    "pvf",
+    "raw",
+    "rf64",
+    "sd2",
+    "sds",
+    "ircam",
+    "voc",
+    "w64",
+    "wav",
+    "nist",
+    "wavex",
+    "wve",
+    "xi",
+    "mp3",
+    "opus",
+]
+WebDataset.AUDIO_EXTENSIONS = AUDIO_EXTENSIONS
+
+
 def text_loads(data: bytes):
     return data.decode("utf-8")
 
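Note: with the builder changes above, an audio WebDataset loads end to end. A hedged sketch (assuming a local `file.tar` laid out like the `audio_wds_file` test fixture below, i.e. `<key>.json` plus `<key>.wav` members):

```python
from datasets import load_dataset

ds = load_dataset("webdataset", data_files={"train": "file.tar"}, split="train")
print(ds.features)  # ..., "json": {...}, "wav": Audio(...)
print(ds[0]["wav"]["sampling_rate"])  # decoded lazily by Audio() from {"path", "bytes"}
```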
112 changes: 98 additions & 14 deletions tests/packaged_modules/test_webdataset.py
@@ -1,45 +1,67 @@
 import json
 import tarfile
 
+import numpy as np
 import pytest
 
-from datasets import DownloadManager, Features, Image, Value
+from datasets import Audio, DownloadManager, Features, Image, Value
 from datasets.packaged_modules.webdataset.webdataset import WebDataset
 
-from ..utils import require_pil
+from ..utils import require_pil, require_sndfile
 
 
 @pytest.fixture
-def tar_file(tmp_path, image_file, text_file):
+def image_wds_file(tmp_path, image_file):
     json_file = tmp_path / "data.json"
     filename = tmp_path / "file.tar"
     num_examples = 3
     with json_file.open("w", encoding="utf-8") as f:
         f.write(json.dumps({"caption": "this is an image"}))
     with tarfile.open(str(filename), "w") as f:
         for example_idx in range(num_examples):
-            f.add(text_file, f"{example_idx:05d}.txt")
             f.add(json_file, f"{example_idx:05d}.json")
             f.add(image_file, f"{example_idx:05d}.jpg")
     return str(filename)
 
 
 @pytest.fixture
-def bad_tar_file(tmp_path, image_file, text_file):
+def audio_wds_file(tmp_path, audio_file):
     json_file = tmp_path / "data.json"
+    filename = tmp_path / "file.tar"
+    num_examples = 3
+    with json_file.open("w", encoding="utf-8") as f:
+        f.write(json.dumps({"transcript": "this is a transcript"}))
+    with tarfile.open(str(filename), "w") as f:
+        for example_idx in range(num_examples):
+            f.add(json_file, f"{example_idx:05d}.json")
+            f.add(audio_file, f"{example_idx:05d}.wav")
+    return str(filename)
+
+
+@pytest.fixture
+def bad_wds_file(tmp_path, image_file, text_file):
+    json_file = tmp_path / "data.json"
     filename = tmp_path / "bad_file.tar"
     with json_file.open("w", encoding="utf-8") as f:
         f.write(json.dumps({"caption": "this is an image"}))
     with tarfile.open(str(filename), "w") as f:
         f.add(image_file)
         f.add(text_file)
         f.add(json_file)
     return str(filename)
 
 
 @require_pil
-def test_webdataset(tar_file):
+def test_image_webdataset(image_wds_file):
     import PIL.Image
 
-    data_files = {"train": [tar_file]}
+    data_files = {"train": [image_wds_file]}
     webdataset = WebDataset(data_files=data_files)
     split_generators = webdataset._split_generators(DownloadManager())
     assert webdataset.info.features == Features(
         {
             "__key__": Value("string"),
             "__url__": Value("string"),
-            "txt": Value("string"),
             "json": {"caption": Value("string")},
             "jpg": Image(),
         }
     )
@@ -49,15 +71,77 @@ def test_webdataset(tar_file):
     generator = webdataset._generate_examples(**split_generator.gen_kwargs)
     _, examples = zip(*generator)
     assert len(examples) == 3
-    assert isinstance(examples[0]["txt"], str)
     assert isinstance(examples[0]["json"], dict)
     assert isinstance(examples[0]["json"]["caption"], str)
     assert isinstance(examples[0]["jpg"], dict)  # keep encoded to avoid unnecessary copies
-    decoded = webdataset.info.features.decode_example(examples[0])
-    assert isinstance(decoded["txt"], str)
+    encoded = webdataset.info.features.encode_example(examples[0])
+    decoded = webdataset.info.features.decode_example(encoded)
     assert isinstance(decoded["json"], dict)
     assert isinstance(decoded["json"]["caption"], str)
     assert isinstance(decoded["jpg"], PIL.Image.Image)
 
 
-def test_webdataset_errors_on_bad_file(bad_tar_file):
-    data_files = {"train": [bad_tar_file]}
+@require_sndfile
+def test_audio_webdataset(audio_wds_file):
+    data_files = {"train": [audio_wds_file]}
+    webdataset = WebDataset(data_files=data_files)
+    split_generators = webdataset._split_generators(DownloadManager())
+    assert webdataset.info.features == Features(
+        {
+            "__key__": Value("string"),
+            "__url__": Value("string"),
+            "json": {"transcript": Value("string")},
+            "wav": Audio(),
+        }
+    )
+    assert len(split_generators) == 1
+    split_generator = split_generators[0]
+    assert split_generator.name == "train"
+    generator = webdataset._generate_examples(**split_generator.gen_kwargs)
+    _, examples = zip(*generator)
+    assert len(examples) == 3
+    assert isinstance(examples[0]["json"], dict)
+    assert isinstance(examples[0]["json"]["transcript"], str)
+    assert isinstance(examples[0]["wav"], dict)
+    assert isinstance(examples[0]["wav"]["bytes"], bytes)  # keep encoded to avoid unnecessary copies
+    encoded = webdataset.info.features.encode_example(examples[0])
+    decoded = webdataset.info.features.decode_example(encoded)
+    assert isinstance(decoded["json"], dict)
+    assert isinstance(decoded["json"]["transcript"], str)
+    assert isinstance(decoded["wav"], dict)
+    assert isinstance(decoded["wav"]["array"], np.ndarray)
+
+
+def test_webdataset_errors_on_bad_file(bad_wds_file):
+    data_files = {"train": [bad_wds_file]}
     webdataset = WebDataset(data_files=data_files)
     with pytest.raises(ValueError):
         webdataset._split_generators(DownloadManager())
+
+
+@require_pil
+def test_webdataset_with_features(image_wds_file):
+    import PIL.Image
+
+    data_files = {"train": [image_wds_file]}
+    features = Features(
+        {
+            "__key__": Value("string"),
+            "__url__": Value("string"),
+            "json": {"caption": Value("string"), "additional_field": Value("int64")},
+            "jpg": Image(),
+        }
+    )
+    webdataset = WebDataset(data_files=data_files, features=features)
+    split_generators = webdataset._split_generators(DownloadManager())
+    assert webdataset.info.features == features
+    split_generator = split_generators[0]
+    assert split_generator.name == "train"
+    generator = webdataset._generate_examples(**split_generator.gen_kwargs)
+    _, example = next(iter(generator))
+    encoded = webdataset.info.features.encode_example(example)
+    decoded = webdataset.info.features.decode_example(encoded)
+    assert decoded["json"]["additional_field"] is None
+    assert isinstance(decoded["json"], dict)
+    assert isinstance(decoded["json"]["caption"], str)
+    assert isinstance(decoded["jpg"], PIL.Image.Image)