Skip to content
21 changes: 5 additions & 16 deletions tests/collections/speaker_tasks/test_diar_neural_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,11 @@

class TestNeuralDiarizerInference:
@pytest.mark.unit
@pytest.mark.parametrize(
"device",
[
torch.device("cpu"),
pytest.param(
torch.device("cuda"),
marks=pytest.mark.skipif(
not torch.cuda.is_available(),
reason='CUDA required for test.',
),
),
],
@pytest.mark.skip(
reason="Unknown timeout issue: This function takes too long to run when it is run with the other tests."
)
@pytest.mark.pleasefixme
@pytest.mark.parametrize("num_speakers", [None, 1])
@pytest.mark.parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@pytest.mark.parametrize("num_speakers", [None])
@pytest.mark.parametrize("max_num_speakers", [4])
def test_msdd_diar_inference(self, tmpdir, test_data_dir, device, num_speakers, max_num_speakers):
"""
Expand All @@ -47,13 +37,12 @@ def test_msdd_diar_inference(self, tmpdir, test_data_dir, device, num_speakers,
- Ensures temporary directory is emptied at the end of diarization
- Sanity check to ensure outputs from diarization are reasonable
"""
audio_filenames = ['an22-flrp-b.wav', 'an90-fbbh-b.wav']
audio_filenames = ['an90-fbbh-b.wav']
audio_paths = [os.path.join(test_data_dir, "asr", "train", "an4", "wav", fp) for fp in audio_filenames]

diarizer = NeuralDiarizer.from_pretrained(model_name='diar_msdd_telephonic').to(device)

out_dir = os.path.join(tmpdir, 'diarize_inference/')

assert diarizer.msdd_model.device.type == device.type
assert diarizer._speaker_model.device.type == device.type
for audio_path in audio_paths:
Expand Down
Loading