Skip to content

Commit 1482aec

Browse files
committed
feat: add everyvoice evaluate cli
currently uses TorchAudio-SQUIM
1 parent 4c8bf94 commit 1482aec

File tree

3 files changed

+202
-1
lines changed

3 files changed

+202
-1
lines changed

everyvoice/cli.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
import sys
33
from enum import Enum
44
from pathlib import Path
5-
from typing import Any, List
5+
from typing import Any, List, Optional
66

77
import typer
8+
from rich import print as rich_print
9+
from rich.panel import Panel
810

911
from everyvoice._version import VERSION
1012
from everyvoice.base_cli.checkpoint import inspect as inspect_checkpoint
@@ -80,10 +82,126 @@ def list_commands(self, ctx):
8082
## Synthesize
8183
8284
Once you have a trained model, generate some audio by running: everyvoice synthesize [text-to-spec|spec-to-wav] [OPTIONS]
85+
86+
## Evaluate
87+
88+
You can also try to evaluate your model by running: everyvoice evaluate [synthesized_audio.wav|folder_containing_wavs] [OPTIONS]
89+
8390
""",
8491
)
8592

8693

94+
@app.command(
    short_help="Evaluate your synthesized audio",
    name="evaluate",
    help="""
# Evaluation help

This command will evaluate an audio file, or a folder containing multiple audio files. Currently this is done by calculating the metrics from Kumar et. al. 2023.
We will report the predicted Wideband Perceptual Estimation of Speech Quality (PESQ), Short-Time Objective Intelligibility (STOI), and Scale-Invariant Signal-to-Distortion Ratio (SI-SDR) by default.
We will also report the estimation of subjective Mean Opinion Score (MOS) if a Non-Matching Reference is provided. Please refer to Kumar et. al. for more information.



Kumar, Anurag, et al. “TorchAudio-Squim: Reference-less Speech Quality and Intelligibility measures in TorchAudio.” ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2023.
""",
)
def evaluate(
    audio_file: Optional[Path] = typer.Option(
        None,
        "--audio-file",
        "-f",
        exists=True,
        dir_okay=False,
        file_okay=True,
        help="The path to a single audio file for evaluation.",
        autocompletion=complete_path,
    ),
    audio_directory: Optional[Path] = typer.Option(
        None,
        "--audio-directory",
        "-d",
        file_okay=False,
        dir_okay=True,
        help="The directory where multiple audio files are located for evaluation",
        autocompletion=complete_path,
    ),
    non_matching_reference: Optional[Path] = typer.Option(
        None,
        "--non-matching-reference",
        "-r",
        exists=True,
        dir_okay=False,
        file_okay=True,
        help="The path to a Non Matching Reference audio file, required for MOS prediction.",
        autocompletion=complete_path,
    ),
):
    """Evaluate a single wav file or every *.wav in a directory with the
    Torchaudio-Squim objective model (STOI/PESQ/SI-SDR), plus the subjective
    model (MOS) when a non-matching reference is supplied.

    Exits with status 1 on invalid argument combinations; prints a table of
    results and, for a directory, also writes them to evaluation.json.
    """
    # Heavy dependencies are imported lazily so `everyvoice --help` stays fast.
    import json

    from tabulate import tabulate
    from tqdm import tqdm

    from everyvoice.evaluation import (
        calculate_objective_metrics_from_single_path,
        calculate_subjective_metrics_from_single_path,
        load_squim_objective_model,
        load_squim_subjective_model,
    )

    # Validate the argument combination *before* loading any (large) models.
    if audio_file and audio_directory:
        print(
            "Sorry, please choose to evaluate either a single file or an entire directory. Got values for both."
        )
        sys.exit(1)
    if not audio_file and not audio_directory:
        print(
            "Sorry, please provide either an audio file (--audio-file) or a directory of audio files (--audio-directory) to evaluate."
        )
        sys.exit(1)

    HEADERS = ["STOI", "PESQ", "SI-SDR"]

    objective_model, o_sr = load_squim_objective_model()
    if non_matching_reference:
        # MOS prediction needs the subjective model and a reference signal.
        subjective_model, s_sr = load_squim_subjective_model()
        HEADERS.append("MOS")

    def calculate_row(single_audio):
        # One table row per audio file: objective metrics, then MOS if available.
        stoi, pesq, si_sdr = calculate_objective_metrics_from_single_path(
            single_audio, objective_model, o_sr
        )
        row = [stoi, pesq, si_sdr]
        if non_matching_reference:
            mos = calculate_subjective_metrics_from_single_path(
                single_audio, non_matching_reference, subjective_model, s_sr
            )
            row.append(mos)
        return row

    if audio_file:
        row = calculate_row(audio_file)
        rich_print(
            Panel(
                tabulate([row], HEADERS, tablefmt="simple"),
                title=f"Objective Metrics for {audio_file}:",
            )
        )
        sys.exit(0)

    if audio_directory:
        results = []
        for wav_file in tqdm(
            audio_directory.glob("*.wav"),
            desc=f"Evaluating files in {audio_directory}",
        ):
            results.append(calculate_row(wav_file))
        rich_print(
            Panel(
                tabulate(results, HEADERS, tablefmt="simple"),
                title=f"Objective Metrics for files in {audio_directory}:",
            )
        )
        print(f"Printing results to {audio_directory / 'evaluation.json'}")
        with open(audio_directory / "evaluation.json", "w") as f:
            json.dump(results, f)
203+
204+
87205
class ModelTypes(str, Enum):
88206
text_to_spec = "text-to-spec"
89207
spec_to_wav = "spec-to-wav"

everyvoice/evaluation.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from os import PathLike
2+
from typing import Any, BinaryIO, Union
3+
4+
5+
def load_squim_objective_model() -> tuple[Any, int]:
    """Load Torchaudio's objective SQUIM model (STOI/PESQ/SI-SDR predictor).

    See https://pytorch.org/audio/main/tutorials/squim_tutorial.html

    Returns:
        tuple[Any, int]: the loaded model and the sampling rate it expects
    """
    from torchaudio.pipelines import SQUIM_OBJECTIVE

    # The pretrained SQUIM pipelines operate on 16 kHz audio.
    squim_model = SQUIM_OBJECTIVE.get_model()
    return (squim_model, 16000)
16+
17+
18+
def load_squim_subjective_model() -> tuple[Any, int]:
    """Load Torchaudio's subjective SQUIM model (MOS predictor).

    See https://pytorch.org/audio/main/tutorials/squim_tutorial.html

    Returns:
        tuple[Any, int]: the loaded model and the sampling rate it expects
    """
    from torchaudio.pipelines import SQUIM_SUBJECTIVE

    # The pretrained SQUIM pipelines operate on 16 kHz audio.
    squim_model = SQUIM_SUBJECTIVE.get_model()
    return (squim_model, 16000)
29+
30+
31+
def process_audio(path: Union[BinaryIO, str, PathLike], sampling_rate: int):
    """Load an audio file and prepare it for a SQUIM model.

    The waveform is resampled to ``sampling_rate`` if needed and given a
    channel dimension when missing.

    Args:
        path: the audio file to load
        sampling_rate: the sampling rate the downstream model expects

    Returns:
        a (1, num_samples) waveform tensor

    Raises:
        ValueError: if the audio has more than one channel
    """
    import torchaudio

    waveform, source_sr = torchaudio.load(str(path))
    # Resample when the source rate differs from what the model expects.
    if source_sr != sampling_rate:
        waveform = torchaudio.functional.resample(waveform, source_sr, sampling_rate)
    # Guarantee a leading channel dimension.
    if waveform.dim() < 2:
        waveform = waveform.unsqueeze(0)
    # SQUIM models only accept single-channel input.
    if waveform.size(0) != 1:
        raise ValueError("Audio for evaluation must be mono (single channel)")
    return waveform
45+
46+
47+
def calculate_objective_metrics_from_single_path(
    audio_path, model, model_sampling_rate
) -> tuple[float, float, float]:
    """Predict objective quality metrics for one audio file.

    Args:
        audio_path: path to the audio file to score
        model: a loaded SQUIM objective model
        model_sampling_rate: the sampling rate the model expects

    Returns:
        tuple[float, float, float]: predicted (STOI, PESQ, SI-SDR)
    """
    import torch

    waveform = process_audio(audio_path, model_sampling_rate)
    # Inference only — no gradients needed.
    with torch.no_grad():
        predictions = model(waveform)
    # Convert the returned tensors to plain Python floats.
    stoi_hyp, pesq_hyp, si_sdr_hyp = (float(p) for p in predictions)
    return stoi_hyp, pesq_hyp, si_sdr_hyp
56+
57+
58+
def calculate_subjective_metrics_from_single_path(
    audio_path, non_matching_reference_path, model, model_sampling_rate
) -> float:
    """Predict the subjective Mean Opinion Score (MOS) for one audio file.

    Args:
        audio_path: path to the audio file to score
        non_matching_reference_path: path to a non-matching reference recording
        model: a loaded SQUIM subjective model
        model_sampling_rate: the sampling rate the model expects

    Returns:
        float: the predicted MOS
    """
    import torch

    waveform = process_audio(audio_path, model_sampling_rate)
    reference = process_audio(non_matching_reference_path, model_sampling_rate)
    # Inference only — no gradients needed.
    with torch.no_grad():
        mos_prediction = model(waveform, reference)
    return float(mos_prediction)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from everyvoice.evaluation import (
2+
calculate_objective_metrics_from_single_path,
3+
load_squim_objective_model,
4+
)
5+
from everyvoice.tests.basic_test_case import BasicTestCase
6+
7+
8+
class EvaluationTest(BasicTestCase):
    """Tests for the Torchaudio-Squim based evaluation helpers."""

    def test_squim_evaluation(self):
        # Score a known fixture file with the objective SQUIM model.
        model, sr = load_squim_objective_model()
        stoi, pesq, si_sdr = calculate_objective_metrics_from_single_path(
            self.data_dir / "LJ010-0008.wav", model, sr
        )
        # STOI is bounded above by 1.0.
        self.assertLess(stoi, 1)
        # Model outputs can drift slightly across torch/torchaudio versions and
        # hardware, so compare against the reference values with a tolerance
        # instead of exact equality on rounded floats.
        self.assertAlmostEqual(pesq, 3.88, delta=0.05)
        self.assertAlmostEqual(si_sdr, 28.64, delta=0.5)

0 commit comments

Comments
 (0)