11import os
22import tarfile
3+ from contextlib import nullcontext
4+ from unittest .mock import patch
35
46import pyarrow as pa
57import pytest
68
79from datasets import Dataset , concatenate_datasets , load_dataset
810from datasets .features import Audio , Features , Sequence , Value
911
10- from ..utils import require_libsndfile_with_opus , require_sndfile , require_sox , require_torchaudio
12+ from ..utils import (
13+ require_libsndfile_with_opus ,
14+ require_sndfile ,
15+ require_sox ,
16+ require_torchaudio ,
17+ require_torchaudio_latest ,
18+ )
1119
1220
1321@pytest .fixture ()
@@ -135,6 +143,26 @@ def test_audio_decode_example_mp3(shared_datadir):
135143 assert decoded_example ["sampling_rate" ] == 44100
136144
137145
@pytest.mark.torchaudio_latest
@require_torchaudio_latest
@pytest.mark.parametrize("torchaudio_failed", [False, True])
def test_audio_decode_example_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
    """Decode an mp3 given by path, both normally and with ``torchaudio.load`` forced to raise.

    When ``torchaudio_failed`` is true, ``torchaudio.load`` is patched to raise
    ``RuntimeError`` and a ``UserWarning`` announcing the librosa fallback is
    expected; the decoded result must be the same in both variants.
    """
    mp3_path = str(shared_datadir / "test_audio_44100.mp3")
    feature = Audio()

    # Only the failing variant needs the patch and the warning check.
    if torchaudio_failed:
        load_ctx = patch("torchaudio.load")
        warning_ctx = pytest.warns(UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?")
    else:
        load_ctx = nullcontext()
        warning_ctx = nullcontext()

    with load_ctx as load_mock, warning_ctx:
        if torchaudio_failed:
            # Make every call to the patched loader fail.
            load_mock.side_effect = RuntimeError()

        decoded = feature.decode_example(feature.encode_example(mp3_path))
        observed = (decoded["path"], decoded["array"].shape, decoded["sampling_rate"])
        assert observed == (mp3_path, (110592,), 44100)
164+
165+
138166@require_libsndfile_with_opus
139167def test_audio_decode_example_opus (shared_datadir ):
140168 audio_path = str (shared_datadir / "test_audio_48000.opus" )
@@ -178,6 +206,34 @@ def test_audio_resampling_mp3_different_sampling_rates(shared_datadir):
178206 assert decoded_example ["sampling_rate" ] == 48000
179207
180208
@pytest.mark.torchaudio_latest
@require_torchaudio_latest
@pytest.mark.parametrize("torchaudio_failed", [False, True])
def test_audio_resampling_mp3_different_sampling_rates_torchaudio_latest(shared_datadir, torchaudio_failed):
    """Decode two mp3 files through ``Audio(sampling_rate=48000)`` and check both report 48000.

    The ``torchaudio_failed`` variant patches ``torchaudio.load`` to raise
    ``RuntimeError`` and expects the librosa-fallback ``UserWarning``.
    """
    feature = Audio(sampling_rate=48000)
    # (source path, expected shape of the decoded array)
    cases = (
        (str(shared_datadir / "test_audio_44100.mp3"), (120373,)),
        (str(shared_datadir / "test_audio_16000.mp3"), (122688,)),
    )

    # If torchaudio>=0.12 fails, the mp3 must still be decoded (with librosa).
    if torchaudio_failed:
        load_ctx = patch("torchaudio.load")
        warning_ctx = pytest.warns(UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?")
    else:
        load_ctx = nullcontext()
        warning_ctx = nullcontext()

    with load_ctx as load_mock, warning_ctx:
        if torchaudio_failed:
            load_mock.side_effect = RuntimeError()

        for source_path, expected_shape in cases:
            decoded = feature.decode_example(feature.encode_example(source_path))
            assert decoded.keys() == {"path", "array", "sampling_rate"}
            assert decoded["path"] == source_path
            assert decoded["array"].shape == expected_shape
            assert decoded["sampling_rate"] == 48000
235+
236+
181237@require_sndfile
182238def test_dataset_with_audio_feature (shared_datadir ):
183239 audio_path = str (shared_datadir / "test_audio_44100.wav" )
@@ -266,6 +322,38 @@ def test_dataset_with_audio_feature_tar_mp3(tar_mp3_path):
266322 assert column [0 ]["sampling_rate" ] == 44100
267323
268324
@pytest.mark.torchaudio_latest
@require_torchaudio_latest
def test_dataset_with_audio_feature_tar_mp3_torchaudio_latest(tar_mp3_path):
    """Build a dataset from the first member of an mp3 tar archive and check item, batch and column access."""

    def check_decoded(decoded):
        # Every access path must yield the same decoded audio dict.
        assert decoded.keys() == {"path", "array", "sampling_rate"}
        assert decoded["path"] == "test_audio_44100.mp3"
        assert decoded["array"].shape == (110592,)
        assert decoded["sampling_rate"] == 44100

    # no test for librosa here because it doesn't support file-like objects, only paths
    entries = []
    for member_path, member_file in iter_archive(tar_mp3_path):
        # Only the first archive member is needed.
        entries.append({"path": member_path, "bytes": member_file.read()})
        break
    dset = Dataset.from_dict({"audio": entries}, features=Features({"audio": Audio()}))

    # Single-row access.
    item = dset[0]
    assert item.keys() == {"audio"}
    check_decoded(item["audio"])

    # Slice (batch) access.
    batch = dset[:1]
    assert batch.keys() == {"audio"}
    assert len(batch["audio"]) == 1
    check_decoded(batch["audio"][0])

    # Column access.
    column = dset["audio"]
    assert len(column) == 1
    check_decoded(column[0])
355+
356+
269357@require_sndfile
270358def test_dataset_with_audio_feature_with_none ():
271359 data = {"audio" : [None ]}
@@ -328,7 +416,7 @@ def test_resampling_at_loading_dataset_with_audio_feature(shared_datadir):
328416
329417
330418@require_sox
331- @require_sndfile
419+ @require_torchaudio
332420def test_resampling_at_loading_dataset_with_audio_feature_mp3 (shared_datadir ):
333421 audio_path = str (shared_datadir / "test_audio_44100.mp3" )
334422 data = {"audio" : [audio_path ]}
@@ -355,6 +443,43 @@ def test_resampling_at_loading_dataset_with_audio_feature_mp3(shared_datadir):
355443 assert column [0 ]["sampling_rate" ] == 16000
356444
357445
@pytest.mark.torchaudio_latest
@require_torchaudio_latest
@pytest.mark.parametrize("torchaudio_failed", [False, True])
def test_resampling_at_loading_dataset_with_audio_feature_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
    """Create a dataset with ``Audio(sampling_rate=16000)`` over an mp3 path and check item, batch and column access.

    The ``torchaudio_failed`` variant patches ``torchaudio.load`` to raise
    ``RuntimeError`` and expects the librosa-fallback ``UserWarning``.
    """
    mp3_path = str(shared_datadir / "test_audio_44100.mp3")
    dset = Dataset.from_dict(
        {"audio": [mp3_path]},
        features=Features({"audio": Audio(sampling_rate=16000)}),
    )

    def check_decoded(decoded):
        # Every access path must yield the same resampled audio dict.
        assert decoded.keys() == {"path", "array", "sampling_rate"}
        assert decoded["path"] == mp3_path
        assert decoded["array"].shape == (40125,)
        assert decoded["sampling_rate"] == 16000

    # If torchaudio>=0.12 fails, the mp3 must still be decoded (with librosa).
    if torchaudio_failed:
        load_ctx = patch("torchaudio.load")
        warning_ctx = pytest.warns(UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?")
    else:
        load_ctx = nullcontext()
        warning_ctx = nullcontext()

    with load_ctx as load_mock, warning_ctx:
        if torchaudio_failed:
            load_mock.side_effect = RuntimeError()

        # Single-row access.
        item = dset[0]
        assert item.keys() == {"audio"}
        check_decoded(item["audio"])

        # Slice (batch) access.
        batch = dset[:1]
        assert batch.keys() == {"audio"}
        assert len(batch["audio"]) == 1
        check_decoded(batch["audio"][0])

        # Column access.
        column = dset["audio"]
        assert len(column) == 1
        check_decoded(column[0])
481+
482+
358483@require_sndfile
359484def test_resampling_after_loading_dataset_with_audio_feature (shared_datadir ):
360485 audio_path = str (shared_datadir / "test_audio_44100.wav" )
@@ -386,7 +511,7 @@ def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir):
386511
387512
388513@require_sox
389- @require_sndfile
514+ @require_torchaudio
390515def test_resampling_after_loading_dataset_with_audio_feature_mp3 (shared_datadir ):
391516 audio_path = str (shared_datadir / "test_audio_44100.mp3" )
392517 data = {"audio" : [audio_path ]}
@@ -416,6 +541,46 @@ def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir)
416541 assert column [0 ]["sampling_rate" ] == 16000
417542
418543
@pytest.mark.torchaudio_latest
@require_torchaudio_latest
@pytest.mark.parametrize("torchaudio_failed", [False, True])
def test_resampling_after_loading_dataset_with_audio_feature_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
    """Load an mp3 with the default ``Audio()``, then cast the column to 16000 and re-check all access paths.

    The ``torchaudio_failed`` variant patches ``torchaudio.load`` to raise
    ``RuntimeError`` and expects the librosa-fallback ``UserWarning``.
    """
    mp3_path = str(shared_datadir / "test_audio_44100.mp3")
    dset = Dataset.from_dict({"audio": [mp3_path]}, features=Features({"audio": Audio()}))

    def check_decoded(decoded):
        # Every access path must yield the same resampled audio dict.
        assert decoded.keys() == {"path", "array", "sampling_rate"}
        assert decoded["path"] == mp3_path
        assert decoded["array"].shape == (40125,)
        assert decoded["sampling_rate"] == 16000

    # If torchaudio>=0.12 fails, the mp3 must still be decoded (with librosa).
    if torchaudio_failed:
        load_ctx = patch("torchaudio.load")
        warning_ctx = pytest.warns(UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?")
    else:
        load_ctx = nullcontext()
        warning_ctx = nullcontext()

    with load_ctx as load_mock, warning_ctx:
        if torchaudio_failed:
            load_mock.side_effect = RuntimeError()

        # Before the cast the file's own rate is reported.
        assert dset[0]["audio"]["sampling_rate"] == 44100

        dset = dset.cast_column("audio", Audio(sampling_rate=16000))

        # Single-row access.
        item = dset[0]
        assert item.keys() == {"audio"}
        check_decoded(item["audio"])

        # Slice (batch) access.
        batch = dset[:1]
        assert batch.keys() == {"audio"}
        assert len(batch["audio"]) == 1
        check_decoded(batch["audio"][0])

        # Column access.
        column = dset["audio"]
        assert len(column) == 1
        check_decoded(column[0])
582+
583+
419584@pytest .mark .parametrize (
420585 "build_data" ,
421586 [
0 commit comments