Skip to content

Commit ebf77d7

Browse files
author
polinaeterna
committed
fix tests: catch warnings with a pytest context manager
1 parent b4a8793 commit ebf77d7

File tree

1 file changed

+18
-35
lines changed

1 file changed

+18
-35
lines changed

tests/features/test_audio.py

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -146,22 +146,21 @@ def test_audio_decode_example_mp3(shared_datadir):
146146
@pytest.mark.torchaudio_latest
147147
@require_torchaudio_latest
148148
@pytest.mark.parametrize("torchaudio_failed", [False, True])
149-
def test_audio_decode_example_mp3_torchaudio_latest(shared_datadir, torchaudio_failed, caplog):
149+
def test_audio_decode_example_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
150150
audio_path = str(shared_datadir / "test_audio_44100.mp3")
151151
audio = Audio()
152152

153-
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock:
153+
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns(
154+
UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
155+
) if torchaudio_failed else nullcontext():
156+
154157
if torchaudio_failed:
155158
load_mock.side_effect = RuntimeError()
156159

157160
decoded_example = audio.decode_example(audio.encode_example(audio_path))
158161
assert decoded_example["path"] == audio_path
159162
assert decoded_example["array"].shape == (110592,)
160163
assert decoded_example["sampling_rate"] == 44100
161-
warning_in_logs = sum(
162-
"decoding mp3 with `librosa` instead of `torchaudio`" in record.msg.lower() for record in caplog.records
163-
)
164-
assert warning_in_logs == 1 if torchaudio_failed else warning_in_logs == 0
165164

166165

167166
@require_libsndfile_with_opus
@@ -210,13 +209,15 @@ def test_audio_resampling_mp3_different_sampling_rates(shared_datadir):
210209
@pytest.mark.torchaudio_latest
211210
@require_torchaudio_latest
212211
@pytest.mark.parametrize("torchaudio_failed", [False, True])
213-
def test_audio_resampling_mp3_different_sampling_rates_torchaudio_latest(shared_datadir, torchaudio_failed, caplog):
212+
def test_audio_resampling_mp3_different_sampling_rates_torchaudio_latest(shared_datadir, torchaudio_failed):
214213
audio_path = str(shared_datadir / "test_audio_44100.mp3")
215214
audio_path2 = str(shared_datadir / "test_audio_16000.mp3")
216215
audio = Audio(sampling_rate=48000)
217216

218217
# if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa)
219-
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock:
218+
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns(
219+
UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
220+
) if torchaudio_failed else nullcontext():
220221
if torchaudio_failed:
221222
load_mock.side_effect = RuntimeError()
222223

@@ -232,12 +233,6 @@ def test_audio_resampling_mp3_different_sampling_rates_torchaudio_latest(shared_
232233
assert decoded_example["array"].shape == (122688,)
233234
assert decoded_example["sampling_rate"] == 48000
234235

235-
# we get warnings at each decoding
236-
warnings_in_logs = sum(
237-
"decoding mp3 with `librosa` instead of `torchaudio`" in record.msg.lower() for record in caplog.records
238-
)
239-
assert warnings_in_logs == 2 if torchaudio_failed else warnings_in_logs == 0
240-
241236

242237
@require_sndfile
243238
def test_dataset_with_audio_feature(shared_datadir):
@@ -329,7 +324,7 @@ def test_dataset_with_audio_feature_tar_mp3(tar_mp3_path):
329324

330325
@pytest.mark.torchaudio_latest
331326
@require_torchaudio_latest
332-
def test_dataset_with_audio_feature_tar_mp3_torchaudio_latest(tar_mp3_path, caplog):
327+
def test_dataset_with_audio_feature_tar_mp3_torchaudio_latest(tar_mp3_path):
333328
# no test for librosa here because it doesn't support file-like objects, only paths
334329
audio_filename = "test_audio_44100.mp3"
335330
data = {"audio": []}
@@ -451,16 +446,16 @@ def test_resampling_at_loading_dataset_with_audio_feature_mp3(shared_datadir):
451446
@pytest.mark.torchaudio_latest
452447
@require_torchaudio_latest
453448
@pytest.mark.parametrize("torchaudio_failed", [False, True])
454-
def test_resampling_at_loading_dataset_with_audio_feature_mp3_torchaudio_latest(
455-
shared_datadir, torchaudio_failed, caplog
456-
):
449+
def test_resampling_at_loading_dataset_with_audio_feature_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
457450
audio_path = str(shared_datadir / "test_audio_44100.mp3")
458451
data = {"audio": [audio_path]}
459452
features = Features({"audio": Audio(sampling_rate=16000)})
460453
dset = Dataset.from_dict(data, features=features)
461454

462455
# if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa)
463-
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock:
456+
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns(
457+
UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
458+
) if torchaudio_failed else nullcontext():
464459
if torchaudio_failed:
465460
load_mock.side_effect = RuntimeError()
466461

@@ -484,12 +479,6 @@ def test_resampling_at_loading_dataset_with_audio_feature_mp3_torchaudio_latest(
484479
assert column[0]["array"].shape == (40125,)
485480
assert column[0]["sampling_rate"] == 16000
486481

487-
# we get warnings at each decoding
488-
warnings_in_logs = sum(
489-
"decoding mp3 with `librosa` instead of `torchaudio`" in record.msg.lower() for record in caplog.records
490-
)
491-
assert warnings_in_logs == 3 if torchaudio_failed else warnings_in_logs == 0
492-
493482

494483
@require_sndfile
495484
def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir):
@@ -555,16 +544,16 @@ def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir)
555544
@pytest.mark.torchaudio_latest
556545
@require_torchaudio_latest
557546
@pytest.mark.parametrize("torchaudio_failed", [False, True])
558-
def test_resampling_after_loading_dataset_with_audio_feature_mp3_torchaudio_latest(
559-
shared_datadir, torchaudio_failed, caplog
560-
):
547+
def test_resampling_after_loading_dataset_with_audio_feature_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
561548
audio_path = str(shared_datadir / "test_audio_44100.mp3")
562549
data = {"audio": [audio_path]}
563550
features = Features({"audio": Audio()})
564551
dset = Dataset.from_dict(data, features=features)
565552

566553
# if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa)
567-
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock:
554+
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns(
555+
UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
556+
) if torchaudio_failed else nullcontext():
568557
if torchaudio_failed:
569558
load_mock.side_effect = RuntimeError()
570559

@@ -591,12 +580,6 @@ def test_resampling_after_loading_dataset_with_audio_feature_mp3_torchaudio_late
591580
assert column[0]["array"].shape == (40125,)
592581
assert column[0]["sampling_rate"] == 16000
593582

594-
# we get warnings at each decoding
595-
warnings_in_logs = sum(
596-
"decoding mp3 with `librosa` instead of `torchaudio`" in record.msg.lower() for record in caplog.records
597-
)
598-
assert warnings_in_logs == 4 if torchaudio_failed else warnings_in_logs == 0
599-
600583

601584
@pytest.mark.parametrize(
602585
"build_data",

0 commit comments

Comments
 (0)