Skip to content

Commit fd3ee74

Browse files
authored
Fix LazyNeMoIterator supervision for multi-channel cuts (#14409)
* Fix LazyNeMoIterator supervision for multi-channel cuts
  Signed-off-by: Ante Jukić <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: anteju <[email protected]>
---------
Signed-off-by: Ante Jukić <[email protected]>
Signed-off-by: anteju <[email protected]>
Co-authored-by: anteju <[email protected]>
1 parent c83adff commit fd3ee74

File tree

2 files changed

+65
-2
lines changed

2 files changed

+65
-2
lines changed

nemo/collections/common/data/lhotse/nemo_adapters.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def __iter__(self) -> Generator[Cut, None, None]:
127127
recording_id=cut.recording_id,
128128
start=0,
129129
duration=cut.duration,
130+
channel=cut.channel,
130131
text=data.get(self.text_field),
131132
language=data.get(self.lang_field),
132133
)

tests/collections/common/test_lhotse_nemo_adapters.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
import numpy as np
1616
import pytest
17-
from lhotse import AudioSource, CutSet, MonoCut, Recording, SupervisionSegment
17+
from lhotse import AudioSource, CutSet, MonoCut, MultiCut, Recording, SupervisionSegment
1818
from lhotse.serialization import save_to_jsonl
19-
from lhotse.testing.dummies import DummyManifest
19+
from lhotse.testing.dummies import DummyManifest, dummy_multi_cut, dummy_supervision
2020

2121
from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator
2222

@@ -41,6 +41,29 @@ def nemo_manifest_path(tmp_path_factory):
4141
return p
4242

4343

44+
@pytest.fixture
def nemo_manifest_path_multichannel(tmp_path_factory):
    """Create a NeMo manifest with 2 one-second utterances, each stored as a 3-channel recording.

    The audio for the dummy multi-channel cuts is saved into a fresh temporary
    directory, and one manifest entry is written per cut. Returns the path to
    the resulting JSON-lines manifest.
    """
    out_dir = tmp_path_factory.mktemp("nemo_data")

    # Build the dummy multi-channel cuts lazily and materialize their audio on disk.
    dummy_cuts = (
        dummy_multi_cut(idx, supervisions=[dummy_supervision(idx)], channel=[0, 1, 2], with_data=True)
        for idx in range(0, 2)
    )
    stored_cuts = CutSet.from_cuts(dummy_cuts).save_audios(out_dir, progress_bar=False)

    # One NeMo-style manifest entry per saved cut.
    entries = [
        {
            "audio_filepath": cut.recording.sources[0].source,
            "text": "irrelevant",
            "duration": cut.duration,
            "lang": "en",
        }
        for cut in stored_cuts
    ]

    manifest_path = out_dir / "nemo_manifest_multichannel.json"
    save_to_jsonl(entries, manifest_path)
    return manifest_path
65+
66+
4467
def test_lazy_nemo_iterator(nemo_manifest_path):
4568
cuts = CutSet(LazyNeMoIterator(nemo_manifest_path))
4669

@@ -78,6 +101,45 @@ def test_lazy_nemo_iterator(nemo_manifest_path):
78101
assert s.language == "en"
79102

80103

104+
def test_lazy_nemo_iterator_multichannel(nemo_manifest_path_multichannel):
    """Verify that LazyNeMoIterator yields 3-channel MultiCuts whose supervisions span all cut channels."""
    cut_set = CutSet(LazyNeMoIterator(nemo_manifest_path_multichannel))

    assert len(cut_set) == 2

    for cut in cut_set:
        # Cut-level properties: a 1 s, 16 kHz cut covering channels 0, 1 and 2.
        assert isinstance(cut, MultiCut)
        assert cut.start == 0.0
        assert cut.duration == 1.0
        assert cut.num_channels == 3
        assert cut.channel == [0, 1, 2]
        assert cut.sampling_rate == 16000
        assert cut.num_samples == 16000

        # Recording-level properties: a single file source with the same channels as the cut.
        assert cut.has_recording
        recording = cut.recording
        assert isinstance(recording, Recording)
        assert recording.duration == 1.0
        assert recording.num_channels == 3
        assert recording.num_samples == 16000
        assert len(recording.sources) == 1
        source = recording.sources[0]
        assert isinstance(source, AudioSource)
        assert source.type == "file"
        assert source.channels == cut.channel

        # Loaded audio has one row per cut channel.
        samples = cut.load_audio()
        assert isinstance(samples, np.ndarray)
        assert samples.shape == (cut.num_channels, 16000)
        assert samples.dtype == np.float32

        # The single supervision must carry the same channels as the cut.
        assert len(cut.supervisions) == 1
        supervision = cut.supervisions[0]
        assert isinstance(supervision, SupervisionSegment)
        assert supervision.start == 0
        assert supervision.duration == 1
        assert supervision.channel == cut.channel
        assert supervision.text == "irrelevant"
        assert supervision.language == "en"
141+
142+
81143
@pytest.fixture
82144
def nemo_offset_manifest_path(tmp_path_factory):
83145
"""

0 commit comments

Comments
 (0)