@@ -104,6 +104,51 @@ def nemo_manifest_path(cutset_path: Path):
104104 return p
105105
106106
@pytest.fixture(scope="session")
def mc_cutset_path(tmp_path_factory) -> Path:
    """10 two-channel utterances of length 1s as a Lhotse CutSet.

    Each multi-channel recording is assembled from the audio sources of
    ``num_channels`` consecutive single-channel dummy cuts. The audio and the
    resulting manifest are written to a session-scoped temporary directory.

    Returns:
        Path to the saved ``mc_cuts.jsonl.gz`` manifest.
    """
    # ``Recording`` is imported here together with the other lhotse names so the
    # fixture does not depend on a module-level import being present.
    from lhotse import CutSet, MultiCut, Recording
    from lhotse.testing.dummies import DummyManifest

    num_examples = 10  # number of examples
    num_channels = 2  # number of channels per example

    # create a dummy manifest with single-channel examples
    sc_cuts = DummyManifest(CutSet, begin_id=0, end_id=num_examples * num_channels, with_data=True)

    # the first single-channel cut provides the metadata for every merged recording
    reference_cut = sc_cuts[0]

    mc_cuts = []
    for n in range(num_examples):
        # sources for individual channels
        mc_sources = []
        for channel in range(num_channels):
            source = sc_cuts[n * num_channels + channel].recording.sources[0]
            # re-label the single-channel source as channel `channel` of the merged recording
            source.channels = [channel]
            mc_sources.append(source)

        # merge recordings
        rec = Recording(
            sources=mc_sources,
            id=f'mc-dummy-recording-{n:02d}',
            num_samples=reference_cut.num_samples,
            duration=reference_cut.duration,
            sampling_rate=reference_cut.sampling_rate,
        )

        # multi-channel cut
        cut = MultiCut(
            recording=rec, id=f'mc-dummy-cut-{n:02d}', start=0, duration=1.0, channel=list(range(num_channels))
        )
        mc_cuts.append(cut)

    mc_cuts = CutSet.from_cuts(mc_cuts)

    tmp_path = tmp_path_factory.mktemp("data")
    p = tmp_path / "mc_cuts.jsonl.gz"
    pa = tmp_path / "mc_audio"
    # store the audio under `pa` and save the manifest that points to it
    mc_cuts.save_audios(pa).to_file(p)
    return p
150+
151+
107152@pytest .fixture (scope = "session" )
108153def nemo_tarred_manifest_path (nemo_manifest_path : Path ) -> Tuple [str , str ]:
109154 """10 utterances of length 1s as a NeMo tarred manifest."""
@@ -247,6 +292,61 @@ def test_dataloader_from_lhotse_cuts_cut_into_windows(cutset_path: Path):
247292 # exactly 20 cuts were used because we cut 10x 1s cuts into 20x 0.5s cuts
248293
249294
def test_dataloader_from_lhotse_cuts_channel_selector(mc_cutset_path: Path):
    """Check that ``channel_selector`` selects the expected channel of multi-channel cuts.

    A reference dataloader without a channel selector is built first. Dataloaders
    with ``channel_selector`` set to ``None``, ``0`` and ``1`` are then compared
    against it batch by batch, using the same seed and settings.
    """
    # settings shared by all dataloaders in this test
    base_config = {
        "cuts_path": mc_cutset_path,
        "sample_rate": 16000,
        "shuffle": True,
        "use_lhotse": True,
        "num_workers": 0,
        "batch_size": 4,
        "seed": 0,
    }

    # Dataloader without channel selector
    config = OmegaConf.create(base_config)

    dl = get_lhotse_dataloader_from_config(
        config=config, global_rank=0, world_size=1, dataset=UnsupervisedAudioDataset()
    )
    batches = list(dl)
    assert len(batches) == 3

    # 1.0s = 16000 samples, two channels, note the constant duration and batch size
    assert batches[0]["audio"].shape == (4, 2, 16000)
    assert batches[1]["audio"].shape == (4, 2, 16000)
    assert batches[2]["audio"].shape == (2, 2, 16000)
    # exactly 10 cuts were used

    # Apply channel selector
    for channel_selector in [None, 0, 1]:
        config_cs = OmegaConf.create({**base_config, "channel_selector": channel_selector})

        dl_cs = get_lhotse_dataloader_from_config(
            config=config_cs, global_rank=0, world_size=1, dataset=UnsupervisedAudioDataset()
        )
        batches_cs = list(dl_cs)

        # guard against a vacuous pass: the comparison below must cover every reference batch
        assert len(batches_cs) == len(batches)

        for b_cs, b_ref in zip(batches_cs, batches):
            if channel_selector is None:
                # no channel selector, needs to match the original dataset
                assert torch.equal(b_cs["audio"], b_ref["audio"])
            else:
                # channel selector, needs to match the selected channel
                assert torch.equal(b_cs["audio"], b_ref["audio"][:, channel_selector, :])
348+
349+
250350@requires_torchaudio
251351def test_dataloader_from_lhotse_shar_cuts (cutset_shar_path : Path ):
252352 config = OmegaConf .create (
0 commit comments