Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
37d9eb2
decode mp3 with librosa if torchaudio is > 0.12 (ideally version of f…
Aug 31, 2022
48991d4
add flake8 ignore for unused imports
Sep 1, 2022
4114bd5
improve error message
Sep 1, 2022
e924ca1
Merge branch 'huggingface:main' into workaround-torchaudio-0.12
Sep 1, 2022
2e12525
Merge branch 'huggingface:main' into workaround-torchaudio-0.12
Sep 7, 2022
f7dbaa8
Merge branch 'huggingface:main' into workaround-torchaudio-0.12
Sep 13, 2022
bd7a1ef
decode mp3 with torchaudio>=0.12 if it works (instead of librosa)
Sep 13, 2022
f7afaff
fix last commit
Sep 13, 2022
4a9da10
fix warnings
Sep 13, 2022
7fca41f
use datasets logging instead of standard
Sep 13, 2022
1305be2
fix incorrect marks for mp3 tests (require torchaudio, not sndfile)
Sep 14, 2022
05d400a
add tests for latest torchaudio + separate stage in CI for it (first …
Sep 14, 2022
6e82a88
get back unintentionally removed require_sox for mp3 tests
Sep 14, 2022
883b6f5
install ffmpeg in CI env to test torchaudio
Sep 15, 2022
d5f06ed
test CI again...
Sep 15, 2022
4fbaeb6
fix pip uninstall - add missing -y param
Sep 15, 2022
2ee379f
install ffmpeg only on ubuntu
Sep 15, 2022
400d8f1
try to compile old version of ffmpeg to test librosa mp3 loading
Sep 15, 2022
942d396
fix ci.yml
Sep 15, 2022
b0f8bb7
add some option to configure of old ffmpeg (i have no idea what it me…
Sep 15, 2022
6d4d9a2
use mock to emulate torchaudio fail, add tests for librosa (not all o…
Sep 15, 2022
3f3cb53
add missing | in ci run
Sep 15, 2022
0748cc4
Merge branch 'huggingface:main' into workaround-torchaudio-0.12
Sep 19, 2022
1cae272
try to skip test if ffmpeg not installed
Sep 19, 2022
4665dcf
remove ffmpeg version checking on windows ci
Sep 19, 2022
7960700
test torchaudio_latest only on ubuntu
Sep 19, 2022
5f0efed
refactor test for latest torchaudio
Sep 19, 2022
bfe1be7
try/except decoding with librosa for file-like objects
Sep 19, 2022
9dd632d
more tests for latest torchaudio, should be complete set now
Sep 19, 2022
852176c
remove unused decorator for ffmpeg checking
Sep 19, 2022
557a9cb
refactor ci workflow (first install, then test)
Sep 19, 2022
3f63882
get back full library testing
Sep 19, 2022
a0672cc
Update tests/utils.py
Sep 20, 2022
420f63d
replace logging with warnings
Sep 20, 2022
b4a8793
Merge branch 'workaround-torchaudio-0.12' of github.com:polinaeterna/…
Sep 20, 2022
c016395
Merge branch 'huggingface:main' into workaround-torchaudio-0.12
Sep 20, 2022
ebf77d7
fix tests: catch warnings with a pytest context manager
Sep 20, 2022
ae67712
Merge branch 'workaround-torchaudio-0.12' of github.com:polinaeterna/…
Sep 20, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,13 @@ jobs:
- name: Test with pytest
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
- name: Install dependencies to test torchaudio>=0.12 on Ubuntu
if: ${{ matrix.os == 'ubuntu-latest' }}
run: |
pip uninstall -y torchaudio torch
pip install "torchaudio>=0.12"
sudo apt-get -y install ffmpeg
- name: Test torchaudio>=0.12 on Ubuntu
if: ${{ matrix.os == 'ubuntu-latest' }}
run: |
python -m pytest -rfExX -m torchaudio_latest -n 2 --dist loadfile -sv ./tests/features/test_audio.py
58 changes: 45 additions & 13 deletions src/datasets/features/audio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import warnings
from dataclasses import dataclass, field
from io import BytesIO
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union
Expand Down Expand Up @@ -268,7 +269,7 @@ def _decode_non_mp3_file_like(self, file, format=None):
if version.parse(sf.__libsndfile_version__) < version.parse("1.0.30"):
raise RuntimeError(
"Decoding .opus files requires 'libsndfile'>=1.0.30, "
+ "it can be installed via conda: `conda install -c conda-forge libsndfile>=1.0.30`"
+ 'it can be installed via conda: `conda install -c conda-forge "libsndfile>=1.0.30"`'
)
array, sampling_rate = sf.read(file)
array = array.T
Expand All @@ -282,19 +283,44 @@ def _decode_non_mp3_file_like(self, file, format=None):
def _decode_mp3(self, path_or_file):
try:
import torchaudio
import torchaudio.transforms as T
except ImportError as err:
raise ImportError(
"Decoding 'mp3' audio files, requires 'torchaudio<0.12.0': pip install 'torchaudio<0.12.0'"
) from err
if not version.parse(torchaudio.__version__) < version.parse("0.12.0"):
raise RuntimeError(
"Decoding 'mp3' audio files, requires 'torchaudio<0.12.0': pip install 'torchaudio<0.12.0'"
)
try:
torchaudio.set_audio_backend("sox_io")
except RuntimeError as err:
raise ImportError("To support decoding 'mp3' audio files, please install 'sox'.") from err
raise ImportError("To support decoding 'mp3' audio files, please install 'torchaudio'.") from err
if version.parse(torchaudio.__version__) < version.parse("0.12.0"):
try:
torchaudio.set_audio_backend("sox_io")
except RuntimeError as err:
raise ImportError("To support decoding 'mp3' audio files, please install 'sox'.") from err
array, sampling_rate = self._decode_mp3_torchaudio(path_or_file)
else:
try: # try torchaudio anyway because sometimes it works (depending on the os and os packages installed)
array, sampling_rate = self._decode_mp3_torchaudio(path_or_file)
except RuntimeError:
try:
# flake8: noqa
import librosa
except ImportError as err:
raise ImportError(
"Your version of `torchaudio` (>=0.12.0) doesn't support decoding 'mp3' files on your machine. "
"To support 'mp3' decoding with `torchaudio>=0.12.0`, please install `ffmpeg>=4` system package "
'or downgrade `torchaudio` to <0.12: `pip install "torchaudio<0.12"`. '
"To support decoding 'mp3' audio files without `torchaudio`, please install `librosa`: "
"`pip install librosa`. Note that decoding will be extremely slow in that case."
) from err
# try to decode with librosa for torchaudio>=0.12.0 as a workaround
warnings.warn("Decoding mp3 with `librosa` instead of `torchaudio`, decoding is slow.")
try:
array, sampling_rate = self._decode_mp3_librosa(path_or_file)
except RuntimeError as err:
raise RuntimeError(
"Decoding of 'mp3' failed, probably because of streaming mode "
"(`librosa` cannot decode 'mp3' file-like objects, only path-like)."
) from err

return array, sampling_rate

def _decode_mp3_torchaudio(self, path_or_file):
import torchaudio
import torchaudio.transforms as T

array, sampling_rate = torchaudio.load(path_or_file, format="mp3")
if self.sampling_rate and self.sampling_rate != sampling_rate:
Expand All @@ -306,3 +332,9 @@ def _decode_mp3(self, path_or_file):
if self.mono:
array = array.mean(axis=0)
return array, sampling_rate

def _decode_mp3_librosa(self, path_or_file):
    """Decode an mp3 with ``librosa`` using this feature's audio settings.

    ``self.mono`` and ``self.sampling_rate`` are forwarded to ``librosa.load``
    as its ``mono`` and ``sr`` arguments.

    Args:
        path_or_file: The input handed straight to ``librosa.load``.

    Returns:
        The ``(array, sampling_rate)`` pair produced by ``librosa.load``.
    """
    # Function-scope import: `librosa` is only needed when this code path runs.
    import librosa

    decoded, rate = librosa.load(path_or_file, sr=self.sampling_rate, mono=self.mono)
    return decoded, rate
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(pytest.mark.unit)


def pytest_configure(config):
    """Register the custom ``torchaudio_latest`` marker on the given pytest config."""
    marker_spec = "torchaudio_latest: mark test to run with torchaudio>=0.12"
    config.addinivalue_line("markers", marker_spec)


@pytest.fixture(autouse=True)
def set_test_cache_config(tmp_path_factory, monkeypatch):
# test_hf_cache_home = tmp_path_factory.mktemp("cache") # TODO: why a cache dir per test function does not work?
Expand Down
171 changes: 168 additions & 3 deletions tests/features/test_audio.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import os
import tarfile
from contextlib import nullcontext
from unittest.mock import patch

import pyarrow as pa
import pytest

from datasets import Dataset, concatenate_datasets, load_dataset
from datasets.features import Audio, Features, Sequence, Value

from ..utils import require_libsndfile_with_opus, require_sndfile, require_sox, require_torchaudio
from ..utils import (
require_libsndfile_with_opus,
require_sndfile,
require_sox,
require_torchaudio,
require_torchaudio_latest,
)


@pytest.fixture()
Expand Down Expand Up @@ -135,6 +143,26 @@ def test_audio_decode_example_mp3(shared_datadir):
assert decoded_example["sampling_rate"] == 44100


@pytest.mark.torchaudio_latest
@require_torchaudio_latest
@pytest.mark.parametrize("torchaudio_failed", [False, True])
def test_audio_decode_example_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
    """Decode one mp3, optionally forcing ``torchaudio.load`` to raise.

    When ``torchaudio_failed`` is true, ``torchaudio.load`` is patched to raise
    ``RuntimeError`` and a ``UserWarning`` about decoding with ``librosa`` must be
    emitted; the decoded result is checked identically in both cases.
    """
    mp3_path = str(shared_datadir / "test_audio_44100.mp3")
    feature = Audio()

    # Both context managers are no-ops unless a torchaudio failure is simulated.
    if torchaudio_failed:
        load_cm = patch("torchaudio.load")
        warning_cm = pytest.warns(
            UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
        )
    else:
        load_cm = nullcontext()
        warning_cm = nullcontext()

    with load_cm as load_mock, warning_cm:
        if torchaudio_failed:
            load_mock.side_effect = RuntimeError()
        decoded = feature.decode_example(feature.encode_example(mp3_path))

    assert decoded["path"] == mp3_path
    assert decoded["array"].shape == (110592,)
    assert decoded["sampling_rate"] == 44100


@require_libsndfile_with_opus
def test_audio_decode_example_opus(shared_datadir):
audio_path = str(shared_datadir / "test_audio_48000.opus")
Expand Down Expand Up @@ -178,6 +206,34 @@ def test_audio_resampling_mp3_different_sampling_rates(shared_datadir):
assert decoded_example["sampling_rate"] == 48000


@pytest.mark.torchaudio_latest
@require_torchaudio_latest
@pytest.mark.parametrize("torchaudio_failed", [False, True])
def test_audio_resampling_mp3_different_sampling_rates_torchaudio_latest(shared_datadir, torchaudio_failed):
# Decodes two mp3 fixtures through an Audio feature created with sampling_rate=48000
# and checks that each decoded example reports sampling_rate == 48000 and the expected shape.
# Runs twice: with torchaudio.load untouched, and with it patched to raise RuntimeError.
audio_path = str(shared_datadir / "test_audio_44100.mp3")
audio_path2 = str(shared_datadir / "test_audio_16000.mp3")
audio = Audio(sampling_rate=48000)

# if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa)
# Each `with` item is a whole conditional expression: the real context manager is used only
# when torchaudio_failed is true, otherwise nullcontext(). `as load_mock` binds the result of
# entering the first item (None for nullcontext()).
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns(
UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
) if torchaudio_failed else nullcontext():
if torchaudio_failed:
# Make the patched torchaudio.load raise when it is called.
load_mock.side_effect = RuntimeError()

decoded_example = audio.decode_example(audio.encode_example(audio_path))
assert decoded_example.keys() == {"path", "array", "sampling_rate"}
assert decoded_example["path"] == audio_path
assert decoded_example["array"].shape == (120373,)
assert decoded_example["sampling_rate"] == 48000

# Same checks for the second fixture file.
decoded_example = audio.decode_example(audio.encode_example(audio_path2))
assert decoded_example.keys() == {"path", "array", "sampling_rate"}
assert decoded_example["path"] == audio_path2
assert decoded_example["array"].shape == (122688,)
assert decoded_example["sampling_rate"] == 48000


@require_sndfile
def test_dataset_with_audio_feature(shared_datadir):
audio_path = str(shared_datadir / "test_audio_44100.wav")
Expand Down Expand Up @@ -266,6 +322,38 @@ def test_dataset_with_audio_feature_tar_mp3(tar_mp3_path):
assert column[0]["sampling_rate"] == 44100


@pytest.mark.torchaudio_latest
@require_torchaudio_latest
def test_dataset_with_audio_feature_tar_mp3_torchaudio_latest(tar_mp3_path):
    """Build a dataset from the bytes of the first archive member and check item, batch and column access."""
    # no test for librosa here because it doesn't support file-like objects, only paths
    expected_name = "test_audio_44100.mp3"

    def check_decoded(decoded):
        # Assertions shared by item, batch and column access.
        assert decoded.keys() == {"path", "array", "sampling_rate"}
        assert decoded["path"] == expected_name
        assert decoded["array"].shape == (110592,)
        assert decoded["sampling_rate"] == 44100

    records = {"audio": []}
    for member_path, member_file in iter_archive(tar_mp3_path):
        records["audio"].append({"path": member_path, "bytes": member_file.read()})
        break  # only the first archive member is used
    audio_features = Features({"audio": Audio()})
    dset = Dataset.from_dict(records, features=audio_features)

    item = dset[0]
    assert item.keys() == {"audio"}
    check_decoded(item["audio"])

    batch = dset[:1]
    assert batch.keys() == {"audio"}
    assert len(batch["audio"]) == 1
    check_decoded(batch["audio"][0])

    column = dset["audio"]
    assert len(column) == 1
    check_decoded(column[0])


@require_sndfile
def test_dataset_with_audio_feature_with_none():
data = {"audio": [None]}
Expand Down Expand Up @@ -328,7 +416,7 @@ def test_resampling_at_loading_dataset_with_audio_feature(shared_datadir):


@require_sox
@require_sndfile
@require_torchaudio
def test_resampling_at_loading_dataset_with_audio_feature_mp3(shared_datadir):
audio_path = str(shared_datadir / "test_audio_44100.mp3")
data = {"audio": [audio_path]}
Expand All @@ -355,6 +443,43 @@ def test_resampling_at_loading_dataset_with_audio_feature_mp3(shared_datadir):
assert column[0]["sampling_rate"] == 16000


@pytest.mark.torchaudio_latest
@require_torchaudio_latest
@pytest.mark.parametrize("torchaudio_failed", [False, True])
def test_resampling_at_loading_dataset_with_audio_feature_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
# Builds a dataset whose Audio feature is created with sampling_rate=16000 and checks that
# item, batch and column access all return audio reporting sampling_rate == 16000.
# Runs twice: with torchaudio.load untouched, and with it patched to raise RuntimeError.
audio_path = str(shared_datadir / "test_audio_44100.mp3")
data = {"audio": [audio_path]}
features = Features({"audio": Audio(sampling_rate=16000)})
dset = Dataset.from_dict(data, features=features)

# if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa)
# Each `with` item is a whole conditional expression: the real context manager is used only
# when torchaudio_failed is true, otherwise nullcontext(). `as load_mock` binds the result of
# entering the first item (None for nullcontext()).
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns(
UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
) if torchaudio_failed else nullcontext():
if torchaudio_failed:
# Make the patched torchaudio.load raise when it is called.
load_mock.side_effect = RuntimeError()

# Single-row access.
item = dset[0]
assert item.keys() == {"audio"}
assert item["audio"].keys() == {"path", "array", "sampling_rate"}
assert item["audio"]["path"] == audio_path
assert item["audio"]["array"].shape == (40125,)
assert item["audio"]["sampling_rate"] == 16000
# Slice (batch) access.
batch = dset[:1]
assert batch.keys() == {"audio"}
assert len(batch["audio"]) == 1
assert batch["audio"][0].keys() == {"path", "array", "sampling_rate"}
assert batch["audio"][0]["path"] == audio_path
assert batch["audio"][0]["array"].shape == (40125,)
assert batch["audio"][0]["sampling_rate"] == 16000
# Column access.
column = dset["audio"]
assert len(column) == 1
assert column[0].keys() == {"path", "array", "sampling_rate"}
assert column[0]["path"] == audio_path
assert column[0]["array"].shape == (40125,)
assert column[0]["sampling_rate"] == 16000


@require_sndfile
def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir):
audio_path = str(shared_datadir / "test_audio_44100.wav")
Expand Down Expand Up @@ -386,7 +511,7 @@ def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir):


@require_sox
@require_sndfile
@require_torchaudio
def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir):
audio_path = str(shared_datadir / "test_audio_44100.mp3")
data = {"audio": [audio_path]}
Expand Down Expand Up @@ -416,6 +541,46 @@ def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir)
assert column[0]["sampling_rate"] == 16000


@pytest.mark.torchaudio_latest
@require_torchaudio_latest
@pytest.mark.parametrize("torchaudio_failed", [False, True])
def test_resampling_after_loading_dataset_with_audio_feature_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
# Builds a dataset with a default Audio() feature, checks the first item reports 44100,
# then casts the column to Audio(sampling_rate=16000) and checks that item, batch and
# column access all report sampling_rate == 16000.
# Runs twice: with torchaudio.load untouched, and with it patched to raise RuntimeError.
audio_path = str(shared_datadir / "test_audio_44100.mp3")
data = {"audio": [audio_path]}
features = Features({"audio": Audio()})
dset = Dataset.from_dict(data, features=features)

# if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa)
# Each `with` item is a whole conditional expression: the real context manager is used only
# when torchaudio_failed is true, otherwise nullcontext(). `as load_mock` binds the result of
# entering the first item (None for nullcontext()).
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns(
UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
) if torchaudio_failed else nullcontext():
if torchaudio_failed:
# Make the patched torchaudio.load raise when it is called.
load_mock.side_effect = RuntimeError()

# Before the cast: the feature has no explicit sampling rate.
item = dset[0]
assert item["audio"]["sampling_rate"] == 44100
# Re-cast the column so later accesses use sampling_rate=16000.
dset = dset.cast_column("audio", Audio(sampling_rate=16000))
# Single-row access.
item = dset[0]
assert item.keys() == {"audio"}
assert item["audio"].keys() == {"path", "array", "sampling_rate"}
assert item["audio"]["path"] == audio_path
assert item["audio"]["array"].shape == (40125,)
assert item["audio"]["sampling_rate"] == 16000
# Slice (batch) access.
batch = dset[:1]
assert batch.keys() == {"audio"}
assert len(batch["audio"]) == 1
assert batch["audio"][0].keys() == {"path", "array", "sampling_rate"}
assert batch["audio"][0]["path"] == audio_path
assert batch["audio"][0]["array"].shape == (40125,)
assert batch["audio"][0]["sampling_rate"] == 16000
# Column access.
column = dset["audio"]
assert len(column) == 1
assert column[0].keys() == {"path", "array", "sampling_rate"}
assert column[0]["path"] == audio_path
assert column[0]["array"].shape == (40125,)
assert column[0]["sampling_rate"] == 16000


@pytest.mark.parametrize(
"build_data",
[
Expand Down
11 changes: 10 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,16 @@ def parse_flag_from_env(key, default=False):
find_library("sox") is None,
reason="test requires sox OS dependency; only available on non-Windows: 'sudo apt-get install sox'",
)
require_torchaudio = pytest.mark.skipif(find_spec("torchaudio") is None, reason="test requires torchaudio")
# Evaluate the torchaudio checks once; the version is only read when the package is importable.
_torchaudio_missing = find_spec("torchaudio") is None
_torchaudio_pre_0_12 = (
    not _torchaudio_missing
    and version.parse(import_module("torchaudio").__version__) < version.parse("0.12.0")
)

# Skip unless torchaudio is installed with a version below 0.12.
require_torchaudio = pytest.mark.skipif(
    _torchaudio_missing or not _torchaudio_pre_0_12,
    reason="test requires torchaudio<0.12",
)
# Skip unless torchaudio is installed with a version of at least 0.12.
require_torchaudio_latest = pytest.mark.skipif(
    _torchaudio_missing or _torchaudio_pre_0_12,
    reason="test requires torchaudio>=0.12",
)


def require_beam(test_case):
Expand Down