Skip to content

Commit 7da9121

Browse files
authored
[ASR] Support for transcription of multi-channel audio for AED models (NVIDIA-NeMo#9007)
* Propagate channel selector for AED model + add channel selector to get_lhotse_dataloader_from config Signed-off-by: Ante Jukić <[email protected]> * Included comments Signed-off-by: Ante Jukić <[email protected]> * Added unit test Signed-off-by: Ante Jukić <[email protected]> --------- Signed-off-by: Ante Jukić <[email protected]>
1 parent 2b6bd58 commit 7da9121

File tree

3 files changed

+129
-0
lines changed

3 files changed

+129
-0
lines changed

nemo/collections/asr/models/aed_multitask_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
875875
'drop_last': False,
876876
'text_field': config.get('text_field', 'answer'),
877877
'lang_field': config.get('lang_field', 'target_lang'),
878+
'channel_selector': config.get('channel_selector', None),
878879
}
879880

880881
temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True)

nemo/collections/common/data/lhotse/dataloader.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class LhotseDataLoadingConfig:
8989
seed: int | str = 0
9090
num_workers: int = 0
9191
pin_memory: bool = False
92+
channel_selector: int | str | None = None
9293

9394
# 4. Optional Lhotse data augmentation.
9495
# a. On-the-fly noise/audio mixing.
@@ -156,6 +157,11 @@ def get_lhotse_dataloader_from_config(
156157
# 1. Load a manifest as a Lhotse CutSet.
157158
cuts, is_tarred = read_cutset_from_config(config)
158159

160+
# Apply channel selector
161+
if config.channel_selector is not None:
162+
logging.info('Using channel selector %s.', config.channel_selector)
163+
cuts = cuts.map(partial(_select_channel, channel_selector=config.channel_selector))
164+
159165
# Resample as a safeguard; it's a no-op when SR is already OK
160166
cuts = cuts.resample(config.sample_rate)
161167

@@ -443,3 +449,25 @@ def _flatten_alt_text(cut) -> list:
443449
text_instance.custom = {"text": data.pop("text"), "lang": data.pop("lang"), **data}
444450
ans.append(text_instance)
445451
return ans
452+
453+
454+
def _select_channel(cut, channel_selector: int | str) -> list:
455+
if isinstance(channel_selector, int):
456+
channel_idx = channel_selector
457+
elif isinstance(channel_selector, str):
458+
if channel_selector in cut.custom:
459+
channel_idx = cut.custom[channel_selector]
460+
else:
461+
raise ValueError(f"Channel selector {channel_selector} not found in cut.custom")
462+
463+
if channel_idx >= cut.num_channels:
464+
raise ValueError(
465+
f"Channel index {channel_idx} is larger than the actual number of channels {cut.num_channels}"
466+
)
467+
468+
if cut.num_channels == 1:
469+
# one channel available and channel_idx==0
470+
return cut
471+
else:
472+
# with_channels only defined on MultiCut
473+
return cut.with_channels(channel_idx)

tests/collections/common/test_lhotse_dataloading.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,51 @@ def nemo_manifest_path(cutset_path: Path):
104104
return p
105105

106106

107+
@pytest.fixture(scope="session")
108+
def mc_cutset_path(tmp_path_factory) -> Path:
109+
"""10 two-channel utterances of length 1s as a Lhotse CutSet."""
110+
from lhotse import CutSet, MultiCut
111+
from lhotse.testing.dummies import DummyManifest
112+
113+
num_examples = 10 # number of examples
114+
num_channels = 2 # number of channels per example
115+
116+
# create a dummy manifest with single-channel examples
117+
sc_cuts = DummyManifest(CutSet, begin_id=0, end_id=num_examples * num_channels, with_data=True)
118+
mc_cuts = []
119+
120+
for n in range(num_examples):
121+
# sources for individual channels
122+
mc_sources = []
123+
for channel in range(num_channels):
124+
source = sc_cuts[n * num_channels + channel].recording.sources[0]
125+
source.channels = [channel]
126+
mc_sources.append(source)
127+
128+
# merge recordings
129+
rec = Recording(
130+
sources=mc_sources,
131+
id=f'mc-dummy-recording-{n:02d}',
132+
num_samples=sc_cuts[0].num_samples,
133+
duration=sc_cuts[0].duration,
134+
sampling_rate=sc_cuts[0].sampling_rate,
135+
)
136+
137+
# multi-channel cut
138+
cut = MultiCut(
139+
recording=rec, id=f'mc-dummy-cut-{n:02d}', start=0, duration=1.0, channel=list(range(num_channels))
140+
)
141+
mc_cuts.append(cut)
142+
143+
mc_cuts = CutSet.from_cuts(mc_cuts)
144+
145+
tmp_path = tmp_path_factory.mktemp("data")
146+
p = tmp_path / "mc_cuts.jsonl.gz"
147+
pa = tmp_path / "mc_audio"
148+
mc_cuts.save_audios(pa).to_file(p)
149+
return p
150+
151+
107152
@pytest.fixture(scope="session")
108153
def nemo_tarred_manifest_path(nemo_manifest_path: Path) -> Tuple[str, str]:
109154
"""10 utterances of length 1s as a NeMo tarred manifest."""
@@ -247,6 +292,61 @@ def test_dataloader_from_lhotse_cuts_cut_into_windows(cutset_path: Path):
247292
# exactly 20 cuts were used because we cut 10x 1s cuts into 20x 0.5s cuts
248293

249294

295+
def test_dataloader_from_lhotse_cuts_channel_selector(mc_cutset_path: Path):
296+
# Dataloader without channel selector
297+
config = OmegaConf.create(
298+
{
299+
"cuts_path": mc_cutset_path,
300+
"sample_rate": 16000,
301+
"shuffle": True,
302+
"use_lhotse": True,
303+
"num_workers": 0,
304+
"batch_size": 4,
305+
"seed": 0,
306+
}
307+
)
308+
309+
dl = get_lhotse_dataloader_from_config(
310+
config=config, global_rank=0, world_size=1, dataset=UnsupervisedAudioDataset()
311+
)
312+
batches = [b for b in dl]
313+
assert len(batches) == 3
314+
315+
# 1.0s = 16000 samples, two channels, note the constant duration and batch size
316+
assert batches[0]["audio"].shape == (4, 2, 16000)
317+
assert batches[1]["audio"].shape == (4, 2, 16000)
318+
assert batches[2]["audio"].shape == (2, 2, 16000)
319+
# exactly 10 cuts were used
320+
321+
# Apply channel selector
322+
for channel_selector in [None, 0, 1]:
323+
324+
config_cs = OmegaConf.create(
325+
{
326+
"cuts_path": mc_cutset_path,
327+
"channel_selector": channel_selector,
328+
"sample_rate": 16000,
329+
"shuffle": True,
330+
"use_lhotse": True,
331+
"num_workers": 0,
332+
"batch_size": 4,
333+
"seed": 0,
334+
}
335+
)
336+
337+
dl_cs = get_lhotse_dataloader_from_config(
338+
config=config_cs, global_rank=0, world_size=1, dataset=UnsupervisedAudioDataset()
339+
)
340+
341+
for n, b_cs in enumerate(dl_cs):
342+
if channel_selector is None:
343+
# no channel selector, needs to match the original dataset
344+
assert torch.equal(b_cs["audio"], batches[n]["audio"])
345+
else:
346+
# channel selector, needs to match the selected channel
347+
assert torch.equal(b_cs["audio"], batches[n]["audio"][:, channel_selector, :])
348+
349+
250350
@requires_torchaudio
251351
def test_dataloader_from_lhotse_shar_cuts(cutset_shar_path: Path):
252352
config = OmegaConf.create(

0 commit comments

Comments
 (0)