From 827b230dc8aaefca287945de91f4db7ebbeb558e Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 9 Jan 2024 16:01:14 +0100 Subject: [PATCH] audio support and bug fix --- src/datasets/features/features.py | 17 +-- .../packaged_modules/webdataset/webdataset.py | 95 ++++++++++++--- tests/packaged_modules/test_webdataset.py | 112 +++++++++++++++--- 3 files changed, 183 insertions(+), 41 deletions(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 6ebdb48741d..344a479612a 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1240,10 +1240,7 @@ def encode_nested_example(schema, obj, level=0): if level == 0 and obj is None: raise ValueError("Got None but expected a dictionary instead") return ( - { - k: encode_nested_example(sub_schema, sub_obj, level=level + 1) - for k, (sub_schema, sub_obj) in zip_dict(schema, obj) - } + {k: encode_nested_example(schema[k], obj.get(k), level=level + 1) for k in schema} if obj is not None else None ) @@ -1269,13 +1266,17 @@ def encode_nested_example(schema, obj, level=0): list_dict = {} if isinstance(obj, (list, tuple)): # obj is a list of dict - for k, dict_tuples in zip_dict(schema.feature, *obj): - list_dict[k] = [encode_nested_example(dict_tuples[0], o, level=level + 1) for o in dict_tuples[1:]] + for k in schema.feature: + list_dict[k] = [encode_nested_example(schema.feature[k], o.get(k), level=level + 1) for o in obj] return list_dict else: # obj is a single dict - for k, (sub_schema, sub_objs) in zip_dict(schema.feature, obj): - list_dict[k] = [encode_nested_example(sub_schema, o, level=level + 1) for o in sub_objs] + for k in schema.feature: + list_dict[k] = ( + [encode_nested_example(schema.feature[k], o, level=level + 1) for o in obj[k]] + if k in obj + else None + ) return list_dict # schema.feature is not a dict if isinstance(obj, str): # don't interpret a string as a list diff --git a/src/datasets/packaged_modules/webdataset/webdataset.py 
b/src/datasets/packaged_modules/webdataset/webdataset.py index e44bf50857c..3ac1e86fc41 100644 --- a/src/datasets/packaged_modules/webdataset/webdataset.py +++ b/src/datasets/packaged_modules/webdataset/webdataset.py @@ -15,6 +15,7 @@ class WebDataset(datasets.GeneratorBasedBuilder): DEFAULT_WRITER_BATCH_SIZE = 100 IMAGE_EXTENSIONS: List[str] # definition at the bottom of the script + AUDIO_EXTENSIONS: List[str] # definition at the bottom of the script DECODERS: Dict[str, Callable[[Any], Any]] # definition at the bottom of the script NUM_EXAMPLES_FOR_FEATURES_INFERENCE = 5 @@ -65,24 +66,33 @@ def _split_generators(self, dl_manager): name=split_name, gen_kwargs={"tar_paths": tar_paths, "tar_iterators": tar_iterators} ) ) - - # Get one example to get the feature types - pipeline = self._get_pipeline_from_tar(tar_paths[0], tar_iterators[0]) - first_examples = list(islice(pipeline, self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE)) - if any(example.keys() != first_examples[0].keys() for example in first_examples): - raise ValueError( - "The TAR archives of the dataset should be in WebDataset format, " - "but the files in the archive don't share the same prefix or the same types." 
- ) - inferred_arrow_schema = pa.Table.from_pylist(first_examples[:1]).schema - features = datasets.Features.from_arrow_schema(inferred_arrow_schema) - - # Set Image types - for field_name in first_examples[0]: - extension = field_name.rsplit(".", 1)[-1] - if extension in self.IMAGE_EXTENSIONS: - features[field_name] = datasets.Image() - self.info.features = features + if not self.info.features: + # Get one example to get the feature types + pipeline = self._get_pipeline_from_tar(tar_paths[0], tar_iterators[0]) + first_examples = list(islice(pipeline, self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE)) + if any(example.keys() != first_examples[0].keys() for example in first_examples): + raise ValueError( + "The TAR archives of the dataset should be in WebDataset format, " + "but the files in the archive don't share the same prefix or the same types." + ) + pa_tables = [pa.Table.from_pylist([example]) for example in first_examples] + if datasets.config.PYARROW_VERSION.major < 14: + inferred_arrow_schema = pa.concat_tables(pa_tables, promote=True).schema + else: + inferred_arrow_schema = pa.concat_tables(pa_tables, promote_options="default").schema + features = datasets.Features.from_arrow_schema(inferred_arrow_schema) + + # Set Image types + for field_name in first_examples[0]: + extension = field_name.rsplit(".", 1)[-1] + if extension in self.IMAGE_EXTENSIONS: + features[field_name] = datasets.Image() + # Set Audio types + for field_name in first_examples[0]: + extension = field_name.rsplit(".", 1)[-1] + if extension in self.AUDIO_EXTENSIONS: + features[field_name] = datasets.Audio() + self.info.features = features return splits @@ -90,9 +100,12 @@ def _generate_examples(self, tar_paths, tar_iterators): image_field_names = [ field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Image) ] + audio_field_names = [ + field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Audio) + ] for tar_idx, 
(tar_path, tar_iterator) in enumerate(zip(tar_paths, tar_iterators)): for example_idx, example in enumerate(self._get_pipeline_from_tar(tar_path, tar_iterator)): - for field_name in image_field_names: + for field_name in image_field_names + audio_field_names: example[field_name] = {"path": example["__key__"] + "." + field_name, "bytes": example[field_name]} yield f"{tar_idx}_{example_idx}", example @@ -177,6 +190,50 @@ def _generate_examples(self, tar_paths, tar_iterators): WebDataset.IMAGE_EXTENSIONS = IMAGE_EXTENSIONS +# Obtained with: +# ``` +# import soundfile as sf +# +# AUDIO_EXTENSIONS = [f"{format.lower()}" for format in sf.available_formats().keys()] +# +# # .mp3 is currently decoded via `torchaudio`, .opus decoding is supported if version of `libsndfile` >= 1.0.30: +# AUDIO_EXTENSIONS.extend(["mp3", "opus"]) +# ``` +# We intentionally do not run this code on launch because: +# (1) Soundfile is an optional dependency, so importing it in the global namespace is not allowed +# (2) To ensure the list of supported extensions is deterministic +AUDIO_EXTENSIONS = [ + "aiff", + "au", + "avr", + "caf", + "flac", + "htk", + "svx", + "mat4", + "mat5", + "mpc2k", + "ogg", + "paf", + "pvf", + "raw", + "rf64", + "sd2", + "sds", + "ircam", + "voc", + "w64", + "wav", + "nist", + "wavex", + "wve", + "xi", + "mp3", + "opus", +] +WebDataset.AUDIO_EXTENSIONS = AUDIO_EXTENSIONS + + def text_loads(data: bytes): return data.decode("utf-8") diff --git a/tests/packaged_modules/test_webdataset.py b/tests/packaged_modules/test_webdataset.py index c9b032dfa51..39963122809 100644 --- a/tests/packaged_modules/test_webdataset.py +++ b/tests/packaged_modules/test_webdataset.py @@ -1,45 +1,67 @@ +import json import tarfile +import numpy as np import pytest -from datasets import DownloadManager, Features, Image, Value +from datasets import Audio, DownloadManager, Features, Image, Value from datasets.packaged_modules.webdataset.webdataset import WebDataset -from ..utils import require_pil
+from ..utils import require_pil, require_sndfile @pytest.fixture -def tar_file(tmp_path, image_file, text_file): +def image_wds_file(tmp_path, image_file): + json_file = tmp_path / "data.json" filename = tmp_path / "file.tar" num_examples = 3 + with json_file.open("w", encoding="utf-8") as f: + f.write(json.dumps({"caption": "this is an image"})) with tarfile.open(str(filename), "w") as f: for example_idx in range(num_examples): - f.add(text_file, f"{example_idx:05d}.txt") + f.add(json_file, f"{example_idx:05d}.json") f.add(image_file, f"{example_idx:05d}.jpg") return str(filename) @pytest.fixture -def bad_tar_file(tmp_path, image_file, text_file): +def audio_wds_file(tmp_path, audio_file): + json_file = tmp_path / "data.json" + filename = tmp_path / "file.tar" + num_examples = 3 + with json_file.open("w", encoding="utf-8") as f: + f.write(json.dumps({"transcript": "this is a transcript"})) + with tarfile.open(str(filename), "w") as f: + for example_idx in range(num_examples): + f.add(json_file, f"{example_idx:05d}.json") + f.add(audio_file, f"{example_idx:05d}.wav") + return str(filename) + + +@pytest.fixture +def bad_wds_file(tmp_path, image_file, text_file): + json_file = tmp_path / "data.json" filename = tmp_path / "bad_file.tar" + with json_file.open("w", encoding="utf-8") as f: + f.write(json.dumps({"caption": "this is an image"})) with tarfile.open(str(filename), "w") as f: f.add(image_file) - f.add(text_file) + f.add(json_file) return str(filename) @require_pil -def test_webdataset(tar_file): +def test_image_webdataset(image_wds_file): import PIL.Image - data_files = {"train": [tar_file]} + data_files = {"train": [image_wds_file]} webdataset = WebDataset(data_files=data_files) split_generators = webdataset._split_generators(DownloadManager()) assert webdataset.info.features == Features( { "__key__": Value("string"), "__url__": Value("string"), - "txt": Value("string"), + "json": {"caption": Value("string")}, "jpg": Image(), } ) @@ -49,15 +71,77 @@ def 
test_webdataset(tar_file): generator = webdataset._generate_examples(**split_generator.gen_kwargs) _, examples = zip(*generator) assert len(examples) == 3 - assert isinstance(examples[0]["txt"], str) + assert isinstance(examples[0]["json"], dict) + assert isinstance(examples[0]["json"]["caption"], str) assert isinstance(examples[0]["jpg"], dict) # keep encoded to avoid unnecessary copies - decoded = webdataset.info.features.decode_example(examples[0]) - assert isinstance(decoded["txt"], str) + encoded = webdataset.info.features.encode_example(examples[0]) + decoded = webdataset.info.features.decode_example(encoded) + assert isinstance(decoded["json"], dict) + assert isinstance(decoded["json"]["caption"], str) assert isinstance(decoded["jpg"], PIL.Image.Image) -def test_webdataset_errors_on_bad_file(bad_tar_file): - data_files = {"train": [bad_tar_file]} +@require_sndfile +def test_audio_webdataset(audio_wds_file): + data_files = {"train": [audio_wds_file]} + webdataset = WebDataset(data_files=data_files) + split_generators = webdataset._split_generators(DownloadManager()) + assert webdataset.info.features == Features( + { + "__key__": Value("string"), + "__url__": Value("string"), + "json": {"transcript": Value("string")}, + "wav": Audio(), + } + ) + assert len(split_generators) == 1 + split_generator = split_generators[0] + assert split_generator.name == "train" + generator = webdataset._generate_examples(**split_generator.gen_kwargs) + _, examples = zip(*generator) + assert len(examples) == 3 + assert isinstance(examples[0]["json"], dict) + assert isinstance(examples[0]["json"]["transcript"], str) + assert isinstance(examples[0]["wav"], dict) + assert isinstance(examples[0]["wav"]["bytes"], bytes) # keep encoded to avoid unnecessary copies + encoded = webdataset.info.features.encode_example(examples[0]) + decoded = webdataset.info.features.decode_example(encoded) + assert isinstance(decoded["json"], dict) + assert isinstance(decoded["json"]["transcript"], str) + 
assert isinstance(decoded["wav"], dict) + assert isinstance(decoded["wav"]["array"], np.ndarray) + + +def test_webdataset_errors_on_bad_file(bad_wds_file): + data_files = {"train": [bad_wds_file]} webdataset = WebDataset(data_files=data_files) with pytest.raises(ValueError): webdataset._split_generators(DownloadManager()) + + +@require_pil +def test_webdataset_with_features(image_wds_file): + import PIL.Image + + data_files = {"train": [image_wds_file]} + features = Features( + { + "__key__": Value("string"), + "__url__": Value("string"), + "json": {"caption": Value("string"), "additional_field": Value("int64")}, + "jpg": Image(), + } + ) + webdataset = WebDataset(data_files=data_files, features=features) + split_generators = webdataset._split_generators(DownloadManager()) + assert webdataset.info.features == features + split_generator = split_generators[0] + assert split_generator.name == "train" + generator = webdataset._generate_examples(**split_generator.gen_kwargs) + _, example = next(iter(generator)) + encoded = webdataset.info.features.encode_example(example) + decoded = webdataset.info.features.decode_example(encoded) + assert decoded["json"]["additional_field"] is None + assert isinstance(decoded["json"], dict) + assert isinstance(decoded["json"]["caption"], str) + assert isinstance(decoded["jpg"], PIL.Image.Image)