Skip to content

Commit 11fdbfb

Browse files
committed
feat(preprocessor): add auto evaluation to check-data
1 parent b40375a commit 11fdbfb

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

everyvoice/preprocessor/preprocessor.py

Lines changed: 19 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -847,8 +847,16 @@ def get_process_fn(self, process):
847847
if process == "attn":
848848
return self.process_attn_prior
849849

850-
def check_data(self, filelist, word_seg_token=" ", heavy_clip_detction=False):
850+
def check_data(
851+
self,
852+
filelist,
853+
word_seg_token=" ",
854+
heavy_clip_detction=False,
855+
heavy_objective_evaluation=False,
856+
):
851857
data = []
858+
if heavy_objective_evaluation:
859+
model = torchaudio.pipelines.SQUIM_OBJECTIVE.get_model()
852860
# speaking rate (words/second, float, scatterplot or bar chart)
853861
# speaking rate (characters/second, float, scatterplot or bar chart)
854862
# articulation level (mean energy/speaking rate)
@@ -879,7 +887,7 @@ def check_data(self, filelist, word_seg_token=" ", heavy_clip_detction=False):
879887
n_phones = (
880888
len(phone_tokens.split("/")) if phone_tokens is not None else None
881889
)
882-
audio, _ = torchaudio.load(
890+
audio, sr = torchaudio.load(
883891
str(
884892
self.create_path(
885893
item, "audio", f"audio-{self.input_sampling_rate}.wav"
@@ -890,6 +898,15 @@ def check_data(self, filelist, word_seg_token=" ", heavy_clip_detction=False):
890898
len(audio.size()) == 1 or audio.size(0) == 1
891899
), f"Audio has {audio.size(0)} channels, but should be mono"
892900
audio = audio.squeeze()
901+
902+
if heavy_objective_evaluation:
903+
# use objective metrics from https://pytorch.org/audio/main/tutorials/squim_tutorial.html
904+
if sr != 16000:
905+
audio = torchaudio.functional.resample(audio, sr, 16000)
906+
stoi_hyp, pesq_hyp, si_sdr_hyp = model(audio)
907+
data_point["stoi"] = float(stoi_hyp[0])
908+
data_point["pesq"] = float(pesq_hyp[0])
909+
data_point["si_sdr"] = float(si_sdr_hyp[0])
893910
if heavy_clip_detction:
894911
_, total_clipping = detect_clipping(audio)
895912
else:

0 commit comments

Comments
 (0)