Skip to content

Commit ff6d555

Browse files
committed
Fix LazyNeMoIterator supervision for multi-channel cuts
Signed-off-by: Ante Jukić <[email protected]>
1 parent 7d361dc commit ff6d555

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

nemo/collections/common/data/lhotse/nemo_adapters.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def __iter__(self) -> Generator[Cut, None, None]:
127127
recording_id=cut.recording_id,
128128
start=0,
129129
duration=cut.duration,
130+
channel=cut.channel,
130131
text=data.get(self.text_field),
131132
language=data.get(self.lang_field),
132133
)

tests/collections/common/test_lhotse_nemo_adapters.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
import numpy as np
1616
import pytest
17-
from lhotse import AudioSource, CutSet, MonoCut, Recording, SupervisionSegment
17+
from lhotse import AudioSource, CutSet, MonoCut, MultiCut, Recording, SupervisionSegment
1818
from lhotse.serialization import save_to_jsonl
19-
from lhotse.testing.dummies import DummyManifest
19+
from lhotse.testing.dummies import DummyManifest, dummy_multi_cut, dummy_supervision
2020

2121
from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator
2222

@@ -40,6 +40,25 @@ def nemo_manifest_path(tmp_path_factory):
4040
save_to_jsonl(nemo, p)
4141
return p
4242

43+
@pytest.fixture
def nemo_manifest_path_multichannel(tmp_path_factory):
    """Write a NeMo manifest with 2 utterances (1s each, 3 channels) and return its path."""
    out_dir = tmp_path_factory.mktemp("nemo_data")

    # Build two dummy multi-channel cuts and store their audio on disk so that
    # the manifest entries can point at real files.
    multi_cuts = [
        dummy_multi_cut(idx, supervisions=[dummy_supervision(idx)], channel=[0, 1, 2], with_data=True)
        for idx in range(0, 2)
    ]
    cuts = CutSet.from_cuts(multi_cuts).save_audios(out_dir, progress_bar=False)

    # One NeMo-style manifest entry per cut.
    entries = [
        {
            "audio_filepath": cut.recording.sources[0].source,
            "text": "irrelevant",
            "duration": cut.duration,
            "lang": "en",
        }
        for cut in cuts
    ]

    manifest_path = out_dir / "nemo_manifest_multichannel.json"
    save_to_jsonl(entries, manifest_path)
    return manifest_path
61+
4362

4463
def test_lazy_nemo_iterator(nemo_manifest_path):
4564
cuts = CutSet(LazyNeMoIterator(nemo_manifest_path))
@@ -77,6 +96,44 @@ def test_lazy_nemo_iterator(nemo_manifest_path):
7796
assert s.text == "irrelevant"
7897
assert s.language == "en"
7998

99+
def test_lazy_nemo_iterator_multichannel(nemo_manifest_path_multichannel):
    """Check that cuts read from a multi-channel NeMo manifest keep all channels, including in the supervision."""
    cuts = CutSet(LazyNeMoIterator(nemo_manifest_path_multichannel))

    assert len(cuts) == 2

    for cut in cuts:
        # Cut-level properties.
        assert isinstance(cut, MultiCut)
        assert cut.start == 0.0
        assert cut.duration == 1.0
        assert cut.num_channels == 3
        assert cut.channel == [0, 1, 2]  # cuts have three channels
        assert cut.sampling_rate == 16000
        assert cut.num_samples == 16000

        # Recording-level properties.
        assert cut.has_recording
        recording = cut.recording
        assert isinstance(recording, Recording)
        assert recording.duration == 1.0
        assert recording.num_channels == 3
        assert recording.num_samples == 16000
        assert len(recording.sources) == 1
        source = recording.sources[0]
        assert isinstance(source, AudioSource)
        assert source.type == "file"
        assert source.channels == cut.channel  # recording has same channels as the cut

        # Loaded audio.
        samples = cut.load_audio()
        assert isinstance(samples, np.ndarray)
        assert samples.shape == (cut.num_channels, 16000)  # audio has same num_channels as the cut
        assert samples.dtype == np.float32

        # Supervision-level properties.
        assert len(cut.supervisions) == 1
        sup = cut.supervisions[0]
        assert isinstance(sup, SupervisionSegment)
        assert sup.start == 0
        assert sup.duration == 1
        assert sup.channel == cut.channel  # supervision has same channels as the cut
        assert sup.text == "irrelevant"
        assert sup.language == "en"
136+
80137

81138
@pytest.fixture
82139
def nemo_offset_manifest_path(tmp_path_factory):

0 commit comments

Comments
 (0)