Skip to content

Commit 97d6582

Browse files
ProGamerGov and lhoestq
authored and committed
Fix WebDatasets KeyError for user-defined Features when a field is missing in an example (#7004)
* Fix KeyError bug

* Add additional check

Co-authored-by: Quentin Lhoest <[email protected]>

* Add test for missing key handling

* update test

---------

Co-authored-by: Quentin Lhoest <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
1 parent cc5ae64 commit 97d6582

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/datasets/packaged_modules/webdataset/webdataset.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,18 @@ def _generate_examples(self, tar_paths, tar_iterators):
109109
audio_field_names = [
110110
field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Audio)
111111
]
112+
all_field_names = list(self.info.features.keys())
112113
for tar_idx, (tar_path, tar_iterator) in enumerate(zip(tar_paths, tar_iterators)):
113114
for example_idx, example in enumerate(self._get_pipeline_from_tar(tar_path, tar_iterator)):
115+
for field_name in all_field_names:
116+
if field_name not in example:
117+
example[field_name] = None
114118
for field_name in image_field_names + audio_field_names:
115-
example[field_name] = {"path": example["__key__"] + "." + field_name, "bytes": example[field_name]}
119+
if example[field_name] is not None:
120+
example[field_name] = {
121+
"path": example["__key__"] + "." + field_name,
122+
"bytes": example[field_name],
123+
}
116124
yield f"{tar_idx}_{example_idx}", example
117125

118126

tests/packaged_modules/test_webdataset.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,37 @@ def test_image_webdataset(image_wds_file):
128128
assert isinstance(decoded["jpg"], PIL.Image.Image)
129129

130130

131+
@require_pil
def test_image_webdataset_missing_keys(image_wds_file):
    """Check that declared-but-absent fields decode to ``None``.

    The user-defined schema lists two extra fields (``jpeg`` and ``txt``) on
    top of ``__key__``, ``__url__``, ``json`` and ``jpg``. After an
    encode/decode round trip of the first generated example, both extras are
    expected to be ``None`` while the other fields decode normally.
    """
    import PIL.Image

    expected_features = Features(
        {
            "__key__": Value("string"),
            "__url__": Value("string"),
            "json": {"caption": Value("string")},
            "jpg": Image(),
            "jpeg": Image(),  # additional field
            "txt": Value("string"),  # additional field
        }
    )
    builder = WebDataset(
        data_files={"train": [image_wds_file]},
        features=expected_features,
    )

    # The user-supplied schema must be kept as-is by the builder.
    splits = builder._split_generators(DownloadManager())
    assert builder.info.features == expected_features

    train_split = splits[0]
    assert train_split.name == "train"

    # Only the first example is needed for the round trip.
    examples = builder._generate_examples(**train_split.gen_kwargs)
    _, first_example = next(iter(examples))
    schema = builder.info.features
    round_tripped = schema.decode_example(schema.encode_example(first_example))

    # Fields checked for normal decoding.
    assert isinstance(round_tripped["json"], dict)
    assert isinstance(round_tripped["json"]["caption"], str)
    assert isinstance(round_tripped["jpg"], PIL.Image.Image)

    # The additional fields must surface as None rather than raising KeyError.
    assert round_tripped["jpeg"] is None
    assert round_tripped["txt"] is None
160+
161+
131162
@require_sndfile
132163
def test_audio_webdataset(audio_wds_file):
133164
data_files = {"train": [audio_wds_file]}

0 commit comments

Comments
 (0)