Skip to content

Commit 8337553

Browse files
committed
Added unit test
Signed-off-by: Ante Jukić <[email protected]>
1 parent 9ed7e9b commit 8337553

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

tests/collections/common/test_lhotse_dataloading.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,51 @@ def nemo_manifest_path(cutset_path: Path):
104104
return p
105105

106106

107+
@pytest.fixture(scope="session")
def mc_cutset_path(tmp_path_factory) -> Path:
    """10 two-channel utterances of length 1s as a Lhotse CutSet.

    Builds ``num_examples * num_channels`` single-channel dummy cuts, then groups
    consecutive audio sources into one ``Recording`` per example, assigning each
    source its own channel index. Each merged recording is wrapped in a
    ``MultiCut`` spanning all channels. The audio is saved under ``mc_audio`` and
    the manifest is written to ``mc_cuts.jsonl.gz`` in a temporary directory.

    Args:
        tmp_path_factory: pytest factory used to create the temporary directory.

    Returns:
        Path to the saved multi-channel cut manifest.
    """
    # Recording is imported here alongside the other lhotse names so that the
    # fixture is self-contained and does not rely on a module-level import.
    from lhotse import CutSet, MultiCut, Recording
    from lhotse.testing.dummies import DummyManifest

    num_examples = 10  # number of examples
    num_channels = 2  # number of channels per example

    # create a dummy manifest with single-channel examples
    sc_cuts = DummyManifest(CutSet, begin_id=0, end_id=num_examples * num_channels, with_data=True)
    # all dummy cuts share the same audio parameters, so the first one is used as reference
    ref_cut = sc_cuts[0]
    mc_cuts = []

    for n in range(num_examples):
        # sources for individual channels
        mc_sources = []
        for channel in range(num_channels):
            source = sc_cuts[n * num_channels + channel].recording.sources[0]
            # re-label the (originally single-channel) source with its channel index
            source.channels = [channel]
            mc_sources.append(source)

        # merge recordings
        rec = Recording(
            sources=mc_sources,
            id=f'mc-dummy-recording-{n:02d}',
            num_samples=ref_cut.num_samples,
            duration=ref_cut.duration,
            sampling_rate=ref_cut.sampling_rate,
        )

        # multi-channel cut
        cut = MultiCut(
            recording=rec,
            id=f'mc-dummy-cut-{n:02d}',
            start=0,
            duration=1.0,
            channel=list(range(num_channels)),
        )
        mc_cuts.append(cut)

    mc_cuts = CutSet.from_cuts(mc_cuts)

    tmp_path = tmp_path_factory.mktemp("data")
    p = tmp_path / "mc_cuts.jsonl.gz"
    pa = tmp_path / "mc_audio"
    # save the audio first, then serialize the manifest that points to the saved audio
    mc_cuts.save_audios(pa).to_file(p)
    return p
150+
151+
107152
@pytest.fixture(scope="session")
108153
def nemo_tarred_manifest_path(nemo_manifest_path: Path) -> Tuple[str, str]:
109154
"""10 utterances of length 1s as a NeMo tarred manifest."""
@@ -247,6 +292,61 @@ def test_dataloader_from_lhotse_cuts_cut_into_windows(cutset_path: Path):
247292
# exactly 20 cuts were used because we cut 10x 1s cuts into 20x 0.5s cuts
248293

249294

295+
def test_dataloader_from_lhotse_cuts_channel_selector(mc_cutset_path: Path):
    """Check that ``channel_selector`` picks the expected channel from multi-channel cuts.

    A reference dataloader is built without a channel selector. Then, for each
    selector in ``[None, 0, 1]``, a second dataloader with the same seed is built
    and its batches are compared against the reference batches (the whole batch
    for ``None``, the selected channel otherwise).

    Args:
        mc_cutset_path: path to the two-channel cut manifest produced by the fixture.
    """
    # Settings shared by all dataloaders in this test; the seed is fixed so that
    # the shuffled batches can be compared between dataloaders.
    base_config = {
        "cuts_path": mc_cutset_path,
        "sample_rate": 16000,
        "shuffle": True,
        "use_lhotse": True,
        "num_workers": 0,
        "batch_size": 4,
        "seed": 0,
    }

    # Dataloader without channel selector
    config = OmegaConf.create(base_config)

    dl = get_lhotse_dataloader_from_config(
        config=config, global_rank=0, world_size=1, dataset=UnsupervisedAudioDataset()
    )
    batches = list(dl)
    assert len(batches) == 3

    # 1.0s = 16000 samples, two channels, note the constant duration and batch size
    assert batches[0]["audio"].shape == (4, 2, 16000)
    assert batches[1]["audio"].shape == (4, 2, 16000)
    assert batches[2]["audio"].shape == (2, 2, 16000)
    # exactly 10 cuts were used
    assert sum(b["audio"].shape[0] for b in batches) == 10

    # Apply channel selector
    for channel_selector in [None, 0, 1]:
        config_cs = OmegaConf.create({**base_config, "channel_selector": channel_selector})

        dl_cs = get_lhotse_dataloader_from_config(
            config=config_cs, global_rank=0, world_size=1, dataset=UnsupervisedAudioDataset()
        )

        batches_cs = list(dl_cs)
        # Without this check a dataloader yielding fewer (or no) batches would pass vacuously.
        assert len(batches_cs) == len(batches)

        for b_cs, b_ref in zip(batches_cs, batches):
            if channel_selector is None:
                # no channel selector, needs to match the original dataset
                assert torch.equal(b_cs["audio"], b_ref["audio"])
            else:
                # channel selector, needs to match the selected channel
                assert torch.equal(b_cs["audio"], b_ref["audio"][:, channel_selector, :])
348+
349+
250350
@requires_torchaudio
251351
def test_dataloader_from_lhotse_shar_cuts(cutset_shar_path: Path):
252352
config = OmegaConf.create(

0 commit comments

Comments
 (0)