Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions src/datasets/features/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,12 @@ class Audio:
Args:
sampling_rate (`int`, *optional*):
Target sampling rate. If `None`, the native sampling rate is used.
mono (`bool`, defaults to `True`):
num_channels (`int`, *optional*):
The desired number of channels of the decoded samples. By default, the number of channels of the source is used.
Currently `None` (number of channels of the source), `1` (mono) or `2` (stereo) channels are supported.
mono (Optiona[`bool`], defaults to `None`):
Whether to convert the audio signal to mono by averaging samples across
channels.
channels. If `None`, the audio signal is left in its original number of channels.
decode (`bool`, defaults to `True`):
Whether to decode the audio data. If `False`,
returns the underlying dictionary in the format `{"path": audio_path, "bytes": audio_bytes}`.
Expand All @@ -78,6 +81,7 @@ class Audio:

sampling_rate: Optional[int] = None
decode: bool = True
num_channels: Optional[int] = None
stream_index: Optional[int] = None
id: Optional[str] = field(default=None, repr=False)
# Automatically constructed
Expand Down Expand Up @@ -126,7 +130,7 @@ def encode_example(self, value: Union[str, bytes, bytearray, dict, "AudioDecoder
buffer = BytesIO()
AudioEncoder(
torch.from_numpy(value["array"].astype(np.float32)), sample_rate=value["sampling_rate"]
).to_file_like(buffer, format="wav")
).to_file_like(buffer, format="wav", num_channels=self.num_channels)
return {"bytes": buffer.getvalue(), "path": None}
elif value.get("path") is not None and os.path.isfile(value["path"]):
# we set "bytes": None to not duplicate the data if they're already available locally
Expand All @@ -143,7 +147,7 @@ def encode_example(self, value: Union[str, bytes, bytearray, dict, "AudioDecoder

buffer = BytesIO()
AudioEncoder(torch.from_numpy(bytes_value), sample_rate=value["sampling_rate"]).to_file_like(
buffer, format="wav"
buffer, format="wav", num_channels=self.num_channels
)
return {"bytes": buffer.getvalue(), "path": None}
else:
Expand Down Expand Up @@ -188,7 +192,9 @@ def decode_example(
raise ValueError(f"An audio sample should have one of 'path' or 'bytes' but both are None in {value}.")

if bytes is None and is_local_path(path):
audio = AudioDecoder(path, stream_index=self.stream_index, sample_rate=self.sampling_rate)
audio = AudioDecoder(
path, stream_index=self.stream_index, sample_rate=self.sampling_rate, num_channels=self.num_channels
)

elif bytes is None:
token_per_repo_id = token_per_repo_id or {}
Expand All @@ -201,10 +207,14 @@ def decode_example(

download_config = DownloadConfig(token=token)
f = xopen(path, "rb", download_config=download_config)
audio = AudioDecoder(f, stream_index=self.stream_index, sample_rate=self.sampling_rate)
audio = AudioDecoder(
f, stream_index=self.stream_index, sample_rate=self.sampling_rate, num_channels=self.num_channels
)

else:
audio = AudioDecoder(bytes, stream_index=self.stream_index, sample_rate=self.sampling_rate)
audio = AudioDecoder(
bytes, stream_index=self.stream_index, sample_rate=self.sampling_rate, num_channels=self.num_channels
)
audio._hf_encoded = {"path": path, "bytes": bytes}
audio.metadata.path = path
return audio
Expand Down Expand Up @@ -312,5 +322,8 @@ def encode_torchcodec_audio(audio: "AudioDecoder") -> dict:

samples = audio.get_all_samples()
buffer = BytesIO()
AudioEncoder(samples.data.cpu(), sample_rate=samples.sample_rate).to_file_like(buffer, format="wav")
num_channels = samples.data.shape[0]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't strictly necessary but added it to be explicit

AudioEncoder(samples.data.cpu(), sample_rate=samples.sample_rate).to_file_like(
buffer, format="wav", num_channels=num_channels
)
return {"bytes": buffer.getvalue(), "path": None}
14 changes: 14 additions & 0 deletions tests/features/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,3 +789,17 @@ def test_audio_embed_storage(shared_datadir):
embedded_storage = Audio().embed_storage(storage)
embedded_example = embedded_storage.to_pylist()[0]
assert embedded_example == {"bytes": open(audio_path, "rb").read(), "path": "test_audio_44100.wav"}


@require_torchcodec
def test_audio_decode_example_opus_convert_to_stereo(shared_datadir):
# GH 7837
from torchcodec.decoders import AudioDecoder

audio_path = str(shared_datadir / "test_audio_48000.opus")
audio = Audio(num_channels=2)
decoded_example = audio.decode_example(audio.encode_example(audio_path))
assert isinstance(decoded_example, AudioDecoder)
samples = decoded_example.get_all_samples()
assert samples.sample_rate == 48000
assert samples.data.shape == (2, 48000)