1+ import json
12import tarfile
23
4+ import numpy as np
35import pytest
46
5- from datasets import DownloadManager , Features , Image , Value
7+ from datasets import Audio , DownloadManager , Features , Image , Value
68from datasets .packaged_modules .webdataset .webdataset import WebDataset
79
8- from ..utils import require_pil
10+ from ..utils import require_pil , require_sndfile
911
1012
@pytest.fixture
def image_wds_file(tmp_path, image_file) -> str:
    """Create a tar archive of three JSON-caption/JPEG pairs and return its path.

    The archive members are named ``00000.json``, ``00000.jpg`` ... up to
    ``00002.json``, ``00002.jpg``; every JSON member holds the same
    ``{"caption": "this is an image"}`` payload.

    Args:
        tmp_path: Directory in which ``data.json`` and ``file.tar`` are written.
        image_file: Path of the image stored under each ``.jpg`` member name.

    Returns:
        The path of the created tar archive, as a string.
    """
    json_file = tmp_path / "data.json"
    filename = tmp_path / "file.tar"
    num_examples = 3
    with json_file.open("w", encoding="utf-8") as f:
        f.write(json.dumps({"caption": "this is an image"}))
    with tarfile.open(str(filename), "w") as f:
        for example_idx in range(num_examples):
            # The two members of an example share the same zero-padded key;
            # the tests in this module expect one example per key.
            f.add(json_file, f"{example_idx:05d}.json")
            f.add(image_file, f"{example_idx:05d}.jpg")
    return str(filename)
2025
2126
@pytest.fixture
def audio_wds_file(tmp_path, audio_file) -> str:
    """Create a tar archive of three JSON-transcript/WAV pairs and return its path.

    The archive members are named ``00000.json``, ``00000.wav`` ... up to
    ``00002.json``, ``00002.wav``; every JSON member holds the same
    ``{"transcript": "this is a transcript"}`` payload.

    Args:
        tmp_path: Directory in which ``data.json`` and ``file.tar`` are written.
        audio_file: Path of the audio file stored under each ``.wav`` member name.

    Returns:
        The path of the created tar archive, as a string.
    """
    json_file = tmp_path / "data.json"
    filename = tmp_path / "file.tar"
    num_examples = 3
    with json_file.open("w", encoding="utf-8") as f:
        f.write(json.dumps({"transcript": "this is a transcript"}))
    with tarfile.open(str(filename), "w") as f:
        for example_idx in range(num_examples):
            # The two members of an example share the same zero-padded key;
            # the tests in this module expect one example per key.
            f.add(json_file, f"{example_idx:05d}.json")
            f.add(audio_file, f"{example_idx:05d}.wav")
    return str(filename)
39+
40+
@pytest.fixture
def bad_wds_file(tmp_path, image_file, text_file) -> str:
    """Create a tar archive that the WebDataset builder is expected to reject.

    Unlike the other archive fixtures in this module, the members are added
    without an explicit archive name, so they are stored under their on-disk
    paths instead of ``<key>.<extension>`` names.
    ``test_webdataset_errors_on_bad_file`` expects ``_split_generators`` to
    raise ``ValueError`` for this archive.

    Args:
        tmp_path: Directory in which ``data.json`` and ``bad_file.tar`` are written.
        image_file: Path of the image added to the archive.
        text_file: Not used by the body.

    Returns:
        The path of the created tar archive, as a string.
    """
    # NOTE(review): `text_file` is requested but never used; it is kept so the
    # fixture's dependencies stay unchanged -- confirm whether it can be dropped.
    json_file = tmp_path / "data.json"
    filename = tmp_path / "bad_file.tar"
    with json_file.open("w", encoding="utf-8") as f:
        f.write(json.dumps({"caption": "this is an image"}))
    with tarfile.open(str(filename), "w") as f:
        f.add(image_file)
        f.add(json_file)
    return str(filename)
2951
3052
@require_pil
def test_image_webdataset(image_wds_file):
    """Check features and examples produced for a JSON + JPEG archive.

    After ``_split_generators`` the builder must report ``__key__``,
    ``__url__``, a nested ``json`` feature with a string ``caption`` and an
    ``Image`` feature for ``jpg``. The generated examples keep the image as a
    dict, and an encode/decode round trip yields a ``PIL.Image.Image``.
    """
    import PIL.Image

    data_files = {"train": [image_wds_file]}
    webdataset = WebDataset(data_files=data_files)
    split_generators = webdataset._split_generators(DownloadManager())
    assert webdataset.info.features == Features(
        {
            "__key__": Value("string"),
            "__url__": Value("string"),
            "json": {"caption": Value("string")},
            "jpg": Image(),
        }
    )
    # NOTE(review): the next three statements were not visible in the reviewed
    # excerpt; they are restored to mirror test_audio_webdataset -- confirm.
    assert len(split_generators) == 1
    split_generator = split_generators[0]
    assert split_generator.name == "train"
    generator = webdataset._generate_examples(**split_generator.gen_kwargs)
    _, examples = zip(*generator)
    assert len(examples) == 3
    assert isinstance(examples[0]["json"], dict)
    assert isinstance(examples[0]["json"]["caption"], str)
    assert isinstance(examples[0]["jpg"], dict)  # keep encoded to avoid unnecessary copies
    encoded = webdataset.info.features.encode_example(examples[0])
    decoded = webdataset.info.features.decode_example(encoded)
    assert isinstance(decoded["json"], dict)
    assert isinstance(decoded["json"]["caption"], str)
    assert isinstance(decoded["jpg"], PIL.Image.Image)
5782
5883
@require_sndfile
def test_audio_webdataset(audio_wds_file):
    """Check features and examples produced for a JSON + WAV archive.

    After ``_split_generators`` the builder must report ``__key__``,
    ``__url__``, a nested ``json`` feature with a string ``transcript`` and an
    ``Audio`` feature for ``wav``. The generated examples keep the audio as
    raw bytes, and an encode/decode round trip yields a dict whose ``array``
    entry is a NumPy array.
    """
    data_files = {"train": [audio_wds_file]}
    webdataset = WebDataset(data_files=data_files)
    split_generators = webdataset._split_generators(DownloadManager())
    assert webdataset.info.features == Features(
        {
            "__key__": Value("string"),
            "__url__": Value("string"),
            "json": {"transcript": Value("string")},
            "wav": Audio(),
        }
    )
    assert len(split_generators) == 1
    split_generator = split_generators[0]
    assert split_generator.name == "train"
    generator = webdataset._generate_examples(**split_generator.gen_kwargs)
    _, examples = zip(*generator)
    assert len(examples) == 3
    assert isinstance(examples[0]["json"], dict)
    assert isinstance(examples[0]["json"]["transcript"], str)
    assert isinstance(examples[0]["wav"], dict)
    assert isinstance(examples[0]["wav"]["bytes"], bytes)  # keep encoded to avoid unnecessary copies
    encoded = webdataset.info.features.encode_example(examples[0])
    decoded = webdataset.info.features.decode_example(encoded)
    assert isinstance(decoded["json"], dict)
    assert isinstance(decoded["json"]["transcript"], str)
    assert isinstance(decoded["wav"], dict)
    assert isinstance(decoded["wav"]["array"], np.ndarray)
113+
114+
def test_webdataset_errors_on_bad_file(bad_wds_file):
    """Check that ``_split_generators`` raises ``ValueError`` for the bad archive."""
    data_files = {"train": [bad_wds_file]}
    webdataset = WebDataset(data_files=data_files)
    # Constructing the builder is expected to succeed; only the call to
    # `_split_generators` sits inside the `raises` context.
    with pytest.raises(ValueError):
        webdataset._split_generators(DownloadManager())
120+
121+
@require_pil
def test_webdataset_with_features(image_wds_file):
    """Check that user-supplied features are kept and applied to the examples.

    The supplied ``json`` feature declares an ``additional_field`` that the
    archive's JSON payload (``{"caption": ...}``) does not contain. The
    builder must report exactly the supplied features, and after an
    encode/decode round trip the missing field must be ``None``.
    """
    import PIL.Image

    data_files = {"train": [image_wds_file]}
    features = Features(
        {
            "__key__": Value("string"),
            "__url__": Value("string"),
            "json": {"caption": Value("string"), "additional_field": Value("int64")},
            "jpg": Image(),
        }
    )
    webdataset = WebDataset(data_files=data_files, features=features)
    split_generators = webdataset._split_generators(DownloadManager())
    assert webdataset.info.features == features
    split_generator = split_generators[0]
    assert split_generator.name == "train"
    generator = webdataset._generate_examples(**split_generator.gen_kwargs)
    # Only the first generated example is inspected.
    _, example = next(iter(generator))
    encoded = webdataset.info.features.encode_example(example)
    decoded = webdataset.info.features.decode_example(encoded)
    assert decoded["json"]["additional_field"] is None
    assert isinstance(decoded["json"], dict)
    assert isinstance(decoded["json"]["caption"], str)
    assert isinstance(decoded["jpg"], PIL.Image.Image)
0 commit comments