Skip to content

Commit a2a572e

Browse files
authored
fix: reduce the excessive test time of test_msdd_diar_inference (#14366)
* Reduced the test time by 4 fold removing redundant testing Signed-off-by: taejinp <[email protected]> * Activated CPU test Signed-off-by: taejinp <[email protected]> * Add CPU and CUDA tests. remmoved print function Signed-off-by: taejinp <[email protected]> * Skipping this test since timeout issue is not solved Signed-off-by: taejinp <[email protected]> * Apply isort and black reformatting Signed-off-by: tango4j <[email protected]> --------- Signed-off-by: taejinp <[email protected]> Signed-off-by: tango4j <[email protected]> Co-authored-by: tango4j <[email protected]>
1 parent a82dce9 commit a2a572e

1 file changed

Lines changed: 5 additions & 16 deletions

File tree

tests/collections/speaker_tasks/test_diar_neural_inference.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,11 @@
2222

2323
class TestNeuralDiarizerInference:
2424
@pytest.mark.unit
25-
@pytest.mark.parametrize(
26-
"device",
27-
[
28-
torch.device("cpu"),
29-
pytest.param(
30-
torch.device("cuda"),
31-
marks=pytest.mark.skipif(
32-
not torch.cuda.is_available(),
33-
reason='CUDA required for test.',
34-
),
35-
),
36-
],
25+
@pytest.mark.skip(
26+
reason="Unknown timeout issue: This function takes too long to run when it is run with the other tests."
3727
)
38-
@pytest.mark.pleasefixme
39-
@pytest.mark.parametrize("num_speakers", [None, 1])
28+
@pytest.mark.parametrize("device", [torch.device("cpu"), torch.device("cuda")])
29+
@pytest.mark.parametrize("num_speakers", [None])
4030
@pytest.mark.parametrize("max_num_speakers", [4])
4131
def test_msdd_diar_inference(self, tmpdir, test_data_dir, device, num_speakers, max_num_speakers):
4232
"""
@@ -47,13 +37,12 @@ def test_msdd_diar_inference(self, tmpdir, test_data_dir, device, num_speakers,
4737
- Ensures temporary directory is emptied at the end of diarization
4838
- Sanity check to ensure outputs from diarization are reasonable
4939
"""
50-
audio_filenames = ['an22-flrp-b.wav', 'an90-fbbh-b.wav']
40+
audio_filenames = ['an90-fbbh-b.wav']
5141
audio_paths = [os.path.join(test_data_dir, "asr", "train", "an4", "wav", fp) for fp in audio_filenames]
5242

5343
diarizer = NeuralDiarizer.from_pretrained(model_name='diar_msdd_telephonic').to(device)
5444

5545
out_dir = os.path.join(tmpdir, 'diarize_inference/')
56-
5746
assert diarizer.msdd_model.device.type == device.type
5847
assert diarizer._speaker_model.device.type == device.type
5948
for audio_path in audio_paths:

0 commit comments

Comments
 (0)