Skip to content

Commit 693418a

Browse files
YooSungHyunlhoestq
andauthored
Update: add using pcm bytes (#4323) (#4409)
* Update: add using pcm bytes * re make style * Update src/datasets/features/audio.py Co-authored-by: Quentin Lhoest <[email protected]> * Update src/datasets/features/audio.py Co-authored-by: Quentin Lhoest <[email protected]> * Update src/datasets/features/audio.py Co-authored-by: Quentin Lhoest <[email protected]> * delete: wrong comment * Update: sampling_rate usage & test source update * Update: pcm2wav bytes don`t need path we can open up soundfile lib Co-authored-by: Quentin Lhoest <[email protected]> * Update: we can get wav style bytes to pcm, so we can read to soundfile lib * Update: pcm doesn`t use path, so check 'None' * Update: not used self`s sampling_rate self.sampling_rate is for decode. so, we have to get value`s sampling_rate Co-authored-by: Quentin Lhoest <[email protected]> * Update: add sampling_rate we have to know sampling_rate in input values variable Co-authored-by: Quentin Lhoest <[email protected]> * Update: sampling_rate variable Co-authored-by: Quentin Lhoest <[email protected]> * Update tests/features/test_audio.py Co-authored-by: Quentin Lhoest <[email protected]> * Update tests/features/test_audio.py Co-authored-by: Quentin Lhoest <[email protected]> * Update tests/features/test_audio.py Co-authored-by: Quentin Lhoest <[email protected]> * Update: replace get sampling_rate * Apply suggestions from code review Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 2f1c41a commit 693418a

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

src/datasets/features/audio.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from io import BytesIO
44
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union
55

6+
import numpy as np
67
import pyarrow as pa
78
from packaging import version
89

@@ -92,7 +93,22 @@ def encode_example(self, value: Union[str, dict]) -> dict:
9293
return {"bytes": buffer.getvalue(), "path": None}
9394
elif value.get("path") is not None and os.path.isfile(value["path"]):
9495
# we set "bytes": None to not duplicate the data if they're already available locally
95-
return {"bytes": None, "path": value.get("path")}
96+
if value["path"].endswith("pcm"):
97+
# "PCM" only has raw audio bytes
98+
if value.get("sampling_rate") is None:
99+
# At least, If you want to convert "PCM-byte" to "WAV-byte", you have to know sampling rate
100+
raise KeyError("To use PCM files, please specify a 'sampling_rate' in Audio object")
101+
if value.get("bytes"):
102+
# If we already had PCM-byte, we don`t have to make "read file, make bytes" (just use it!)
103+
bytes_value = np.frombuffer(value["bytes"], dtype=np.int16).astype(np.float32) / 32767
104+
else:
105+
bytes_value = np.memmap(value["path"], dtype="h", mode="r").astype(np.float32) / 32767
106+
107+
buffer = BytesIO(bytes())
108+
sf.write(buffer, bytes_value, value["sampling_rate"], format="wav")
109+
return {"bytes": buffer.getvalue(), "path": None}
110+
else:
111+
return {"bytes": None, "path": value.get("path")}
96112
elif value.get("bytes") is not None or value.get("path") is not None:
97113
# store the audio bytes, and path is used to infer the audio format using the file extension
98114
return {"bytes": value.get("bytes"), "path": value.get("path")}
31.7 KB
Binary file not shown.

tests/features/test_audio.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,26 @@ def test_audio_feature_encode_example(shared_datadir, build_example):
7878
assert decoded_example.keys() == {"path", "array", "sampling_rate"}
7979

8080

81+
@pytest.mark.parametrize(
82+
"build_example",
83+
[
84+
lambda audio_path: {"path": audio_path, "sampling_rate": 16_000},
85+
lambda audio_path: {"path": audio_path, "bytes": None, "sampling_rate": 16_000},
86+
lambda audio_path: {"path": audio_path, "bytes": open(audio_path, "rb").read(), "sampling_rate": 16_000},
87+
lambda audio_path: {"array": [0.1, 0.2, 0.3], "sampling_rate": 16_000},
88+
],
89+
)
90+
def test_audio_feature_encode_example_pcm(shared_datadir, build_example):
91+
audio_path = str(shared_datadir / "test_audio_16000.pcm")
92+
audio = Audio(sampling_rate=16_000)
93+
encoded_example = audio.encode_example(build_example(audio_path))
94+
assert isinstance(encoded_example, dict)
95+
assert encoded_example.keys() == {"bytes", "path"}
96+
assert encoded_example["bytes"] is not None or encoded_example["path"] is not None
97+
decoded_example = audio.decode_example(encoded_example)
98+
assert decoded_example.keys() == {"path", "array", "sampling_rate"}
99+
100+
81101
@require_sndfile
82102
def test_audio_decode_example(shared_datadir):
83103
audio_path = str(shared_datadir / "test_audio_44100.wav")
@@ -126,6 +146,18 @@ def test_audio_decode_example_opus(shared_datadir):
126146
assert decoded_example["sampling_rate"] == 48000
127147

128148

149+
@pytest.mark.parametrize("sampling_rate", [16_000, 48_000])
150+
def test_audio_decode_example_pcm(shared_datadir, sampling_rate):
151+
audio_path = str(shared_datadir / "test_audio_16000.pcm")
152+
audio_input = {"path": audio_path, "sampling_rate": 16_000}
153+
audio = Audio(sampling_rate=sampling_rate)
154+
decoded_example = audio.decode_example(audio.encode_example(audio_input))
155+
assert decoded_example.keys() == {"path", "array", "sampling_rate"}
156+
assert decoded_example["path"] is None
157+
assert decoded_example["array"].shape == (16208 * sampling_rate // 16_000,)
158+
assert decoded_example["sampling_rate"] == sampling_rate
159+
160+
129161
@require_sox
130162
@require_torchaudio
131163
def test_audio_resampling_mp3_different_sampling_rates(shared_datadir):

0 commit comments

Comments
 (0)