Skip to content

Commit 4a5b7d9

Browse files
authored
[WebDataset] Audio support and bug fixes (#6573)
audio support and bug fix
1 parent 999790b commit 4a5b7d9

File tree

3 files changed

+183
-41
lines changed

3 files changed

+183
-41
lines changed

src/datasets/features/features.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,10 +1240,7 @@ def encode_nested_example(schema, obj, level=0):
12401240
if level == 0 and obj is None:
12411241
raise ValueError("Got None but expected a dictionary instead")
12421242
return (
1243-
{
1244-
k: encode_nested_example(sub_schema, sub_obj, level=level + 1)
1245-
for k, (sub_schema, sub_obj) in zip_dict(schema, obj)
1246-
}
1243+
{k: encode_nested_example(schema[k], obj.get(k), level=level + 1) for k in schema}
12471244
if obj is not None
12481245
else None
12491246
)
@@ -1269,13 +1266,17 @@ def encode_nested_example(schema, obj, level=0):
12691266
list_dict = {}
12701267
if isinstance(obj, (list, tuple)):
12711268
# obj is a list of dict
1272-
for k, dict_tuples in zip_dict(schema.feature, *obj):
1273-
list_dict[k] = [encode_nested_example(dict_tuples[0], o, level=level + 1) for o in dict_tuples[1:]]
1269+
for k in schema.feature:
1270+
list_dict[k] = [encode_nested_example(schema.feature[k], o.get(k), level=level + 1) for o in obj]
12741271
return list_dict
12751272
else:
12761273
# obj is a single dict
1277-
for k, (sub_schema, sub_objs) in zip_dict(schema.feature, obj):
1278-
list_dict[k] = [encode_nested_example(sub_schema, o, level=level + 1) for o in sub_objs]
1274+
for k in schema.feature:
1275+
list_dict[k] = (
1276+
[encode_nested_example(schema.feature[k], o, level=level + 1) for o in obj[k]]
1277+
if k in obj
1278+
else None
1279+
)
12791280
return list_dict
12801281
# schema.feature is not a dict
12811282
if isinstance(obj, str): # don't interpret a string as a list

src/datasets/packaged_modules/webdataset/webdataset.py

Lines changed: 76 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
class WebDataset(datasets.GeneratorBasedBuilder):
1616
DEFAULT_WRITER_BATCH_SIZE = 100
1717
IMAGE_EXTENSIONS: List[str] # definition at the bottom of the script
18+
AUDIO_EXTENSIONS: List[str] # definition at the bottom of the script
1819
DECODERS: Dict[str, Callable[[Any], Any]] # definition at the bottom of the script
1920
NUM_EXAMPLES_FOR_FEATURES_INFERENCE = 5
2021

@@ -65,34 +66,46 @@ def _split_generators(self, dl_manager):
6566
name=split_name, gen_kwargs={"tar_paths": tar_paths, "tar_iterators": tar_iterators}
6667
)
6768
)
68-
69-
# Get one example to get the feature types
70-
pipeline = self._get_pipeline_from_tar(tar_paths[0], tar_iterators[0])
71-
first_examples = list(islice(pipeline, self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE))
72-
if any(example.keys() != first_examples[0].keys() for example in first_examples):
73-
raise ValueError(
74-
"The TAR archives of the dataset should be in WebDataset format, "
75-
"but the files in the archive don't share the same prefix or the same types."
76-
)
77-
inferred_arrow_schema = pa.Table.from_pylist(first_examples[:1]).schema
78-
features = datasets.Features.from_arrow_schema(inferred_arrow_schema)
79-
80-
# Set Image types
81-
for field_name in first_examples[0]:
82-
extension = field_name.rsplit(".", 1)[-1]
83-
if extension in self.IMAGE_EXTENSIONS:
84-
features[field_name] = datasets.Image()
85-
self.info.features = features
69+
if not self.info.features:
70+
# Get one example to get the feature types
71+
pipeline = self._get_pipeline_from_tar(tar_paths[0], tar_iterators[0])
72+
first_examples = list(islice(pipeline, self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE))
73+
if any(example.keys() != first_examples[0].keys() for example in first_examples):
74+
raise ValueError(
75+
"The TAR archives of the dataset should be in WebDataset format, "
76+
"but the files in the archive don't share the same prefix or the same types."
77+
)
78+
pa_tables = [pa.Table.from_pylist([example]) for example in first_examples]
79+
if datasets.config.PYARROW_VERSION.major < 14:
80+
inferred_arrow_schema = pa.concat_tables(pa_tables, promote=True).schema
81+
else:
82+
inferred_arrow_schema = pa.concat_tables(pa_tables, promote_options="default").schema
83+
features = datasets.Features.from_arrow_schema(inferred_arrow_schema)
84+
85+
# Set Image types
86+
for field_name in first_examples[0]:
87+
extension = field_name.rsplit(".", 1)[-1]
88+
if extension in self.IMAGE_EXTENSIONS:
89+
features[field_name] = datasets.Image()
90+
# Set Audio types
91+
for field_name in first_examples[0]:
92+
extension = field_name.rsplit(".", 1)[-1]
93+
if extension in self.AUDIO_EXTENSIONS:
94+
features[field_name] = datasets.Audio()
95+
self.info.features = features
8696

8797
return splits
8898

8999
def _generate_examples(self, tar_paths, tar_iterators):
90100
image_field_names = [
91101
field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Image)
92102
]
103+
audio_field_names = [
104+
field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Audio)
105+
]
93106
for tar_idx, (tar_path, tar_iterator) in enumerate(zip(tar_paths, tar_iterators)):
94107
for example_idx, example in enumerate(self._get_pipeline_from_tar(tar_path, tar_iterator)):
95-
for field_name in image_field_names:
108+
for field_name in image_field_names + audio_field_names:
96109
example[field_name] = {"path": example["__key__"] + "." + field_name, "bytes": example[field_name]}
97110
yield f"{tar_idx}_{example_idx}", example
98111

@@ -177,6 +190,50 @@ def _generate_examples(self, tar_paths, tar_iterators):
177190
WebDataset.IMAGE_EXTENSIONS = IMAGE_EXTENSIONS
178191

179192

193+
# Obtained with:
194+
# ```
195+
# import soundfile as sf
196+
#
197+
# AUDIO_EXTENSIONS = [f".{format.lower()}" for format in sf.available_formats().keys()]
198+
#
199+
# # .mp3 is currently decoded via `torchaudio`, .opus decoding is supported if version of `libsndfile` >= 1.0.30:
200+
# AUDIO_EXTENSIONS.extend([".mp3", ".opus"])
201+
# ```
202+
# We intentionally do not run this code on launch because:
203+
# (1) Soundfile is an optional dependency, so importing it in global namespace is not allowed
204+
# (2) To ensure the list of supported extensions is deterministic
205+
AUDIO_EXTENSIONS = [
206+
"aiff",
207+
"au",
208+
"avr",
209+
"caf",
210+
"flac",
211+
"htk",
212+
"svx",
213+
"mat4",
214+
"mat5",
215+
"mpc2k",
216+
"ogg",
217+
"paf",
218+
"pvf",
219+
"raw",
220+
"rf64",
221+
"sd2",
222+
"sds",
223+
"ircam",
224+
"voc",
225+
"w64",
226+
"wav",
227+
"nist",
228+
"wavex",
229+
"wve",
230+
"xi",
231+
"mp3",
232+
"opus",
233+
]
234+
WebDataset.AUDIO_EXTENSIONS = AUDIO_EXTENSIONS
235+
236+
180237
def text_loads(data: bytes):
181238
return data.decode("utf-8")
182239

Lines changed: 98 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,67 @@
1+
import json
12
import tarfile
23

4+
import numpy as np
35
import pytest
46

5-
from datasets import DownloadManager, Features, Image, Value
7+
from datasets import Audio, DownloadManager, Features, Image, Value
68
from datasets.packaged_modules.webdataset.webdataset import WebDataset
79

8-
from ..utils import require_pil
10+
from ..utils import require_pil, require_sndfile
911

1012

1113
@pytest.fixture
12-
def tar_file(tmp_path, image_file, text_file):
14+
def image_wds_file(tmp_path, image_file):
15+
json_file = tmp_path / "data.json"
1316
filename = tmp_path / "file.tar"
1417
num_examples = 3
18+
with json_file.open("w", encoding="utf-8") as f:
19+
f.write(json.dumps({"caption": "this is an image"}))
1520
with tarfile.open(str(filename), "w") as f:
1621
for example_idx in range(num_examples):
17-
f.add(text_file, f"{example_idx:05d}.txt")
22+
f.add(json_file, f"{example_idx:05d}.json")
1823
f.add(image_file, f"{example_idx:05d}.jpg")
1924
return str(filename)
2025

2126

2227
@pytest.fixture
23-
def bad_tar_file(tmp_path, image_file, text_file):
28+
def audio_wds_file(tmp_path, audio_file):
29+
json_file = tmp_path / "data.json"
30+
filename = tmp_path / "file.tar"
31+
num_examples = 3
32+
with json_file.open("w", encoding="utf-8") as f:
33+
f.write(json.dumps({"transcript": "this is a transcript"}))
34+
with tarfile.open(str(filename), "w") as f:
35+
for example_idx in range(num_examples):
36+
f.add(json_file, f"{example_idx:05d}.json")
37+
f.add(audio_file, f"{example_idx:05d}.wav")
38+
return str(filename)
39+
40+
41+
@pytest.fixture
42+
def bad_wds_file(tmp_path, image_file, text_file):
43+
json_file = tmp_path / "data.json"
2444
filename = tmp_path / "bad_file.tar"
45+
with json_file.open("w", encoding="utf-8") as f:
46+
f.write(json.dumps({"caption": "this is an image"}))
2547
with tarfile.open(str(filename), "w") as f:
2648
f.add(image_file)
27-
f.add(text_file)
49+
f.add(json_file)
2850
return str(filename)
2951

3052

3153
@require_pil
32-
def test_webdataset(tar_file):
54+
def test_image_webdataset(image_wds_file):
3355
import PIL.Image
3456

35-
data_files = {"train": [tar_file]}
57+
data_files = {"train": [image_wds_file]}
3658
webdataset = WebDataset(data_files=data_files)
3759
split_generators = webdataset._split_generators(DownloadManager())
3860
assert webdataset.info.features == Features(
3961
{
4062
"__key__": Value("string"),
4163
"__url__": Value("string"),
42-
"txt": Value("string"),
64+
"json": {"caption": Value("string")},
4365
"jpg": Image(),
4466
}
4567
)
@@ -49,15 +71,77 @@ def test_webdataset(tar_file):
4971
generator = webdataset._generate_examples(**split_generator.gen_kwargs)
5072
_, examples = zip(*generator)
5173
assert len(examples) == 3
52-
assert isinstance(examples[0]["txt"], str)
74+
assert isinstance(examples[0]["json"], dict)
75+
assert isinstance(examples[0]["json"]["caption"], str)
5376
assert isinstance(examples[0]["jpg"], dict)  # keep encoded to avoid unnecessary copies
54-
decoded = webdataset.info.features.decode_example(examples[0])
55-
assert isinstance(decoded["txt"], str)
77+
encoded = webdataset.info.features.encode_example(examples[0])
78+
decoded = webdataset.info.features.decode_example(encoded)
79+
assert isinstance(decoded["json"], dict)
80+
assert isinstance(decoded["json"]["caption"], str)
5681
assert isinstance(decoded["jpg"], PIL.Image.Image)
5782

5883

59-
def test_webdataset_errors_on_bad_file(bad_tar_file):
60-
data_files = {"train": [bad_tar_file]}
84+
@require_sndfile
85+
def test_audio_webdataset(audio_wds_file):
86+
data_files = {"train": [audio_wds_file]}
87+
webdataset = WebDataset(data_files=data_files)
88+
split_generators = webdataset._split_generators(DownloadManager())
89+
assert webdataset.info.features == Features(
90+
{
91+
"__key__": Value("string"),
92+
"__url__": Value("string"),
93+
"json": {"transcript": Value("string")},
94+
"wav": Audio(),
95+
}
96+
)
97+
assert len(split_generators) == 1
98+
split_generator = split_generators[0]
99+
assert split_generator.name == "train"
100+
generator = webdataset._generate_examples(**split_generator.gen_kwargs)
101+
_, examples = zip(*generator)
102+
assert len(examples) == 3
103+
assert isinstance(examples[0]["json"], dict)
104+
assert isinstance(examples[0]["json"]["transcript"], str)
105+
assert isinstance(examples[0]["wav"], dict)
106+
assert isinstance(examples[0]["wav"]["bytes"], bytes)  # keep encoded to avoid unnecessary copies
107+
encoded = webdataset.info.features.encode_example(examples[0])
108+
decoded = webdataset.info.features.decode_example(encoded)
109+
assert isinstance(decoded["json"], dict)
110+
assert isinstance(decoded["json"]["transcript"], str)
111+
assert isinstance(decoded["wav"], dict)
112+
assert isinstance(decoded["wav"]["array"], np.ndarray)
113+
114+
115+
def test_webdataset_errors_on_bad_file(bad_wds_file):
116+
data_files = {"train": [bad_wds_file]}
61117
webdataset = WebDataset(data_files=data_files)
62118
with pytest.raises(ValueError):
63119
webdataset._split_generators(DownloadManager())
120+
121+
122+
@require_pil
123+
def test_webdataset_with_features(image_wds_file):
124+
import PIL.Image
125+
126+
data_files = {"train": [image_wds_file]}
127+
features = Features(
128+
{
129+
"__key__": Value("string"),
130+
"__url__": Value("string"),
131+
"json": {"caption": Value("string"), "additional_field": Value("int64")},
132+
"jpg": Image(),
133+
}
134+
)
135+
webdataset = WebDataset(data_files=data_files, features=features)
136+
split_generators = webdataset._split_generators(DownloadManager())
137+
assert webdataset.info.features == features
138+
split_generator = split_generators[0]
139+
assert split_generator.name == "train"
140+
generator = webdataset._generate_examples(**split_generator.gen_kwargs)
141+
_, example = next(iter(generator))
142+
encoded = webdataset.info.features.encode_example(example)
143+
decoded = webdataset.info.features.decode_example(encoded)
144+
assert decoded["json"]["additional_field"] is None
145+
assert isinstance(decoded["json"], dict)
146+
assert isinstance(decoded["json"]["caption"], str)
147+
assert isinstance(decoded["jpg"], PIL.Image.Image)

0 commit comments

Comments (0)