@@ -847,8 +847,16 @@ def get_process_fn(self, process):
847847 if process == "attn" :
848848 return self .process_attn_prior
849849
850- def check_data (self , filelist , word_seg_token = " " , heavy_clip_detction = False ):
850+ def check_data (
851+ self ,
852+ filelist ,
853+ word_seg_token = " " ,
854+ heavy_clip_detction = False ,
855+ heavy_objective_evaluation = False ,
856+ ):
851857 data = []
858+ if heavy_objective_evaluation :
859+ model = torchaudio .pipelines .SQUIM_OBJECTIVE .get_model ()
852860 # speaking rate (words/second, float, scatterplot or bar chart)
853861 # speaking rate (characters/second, float, scatterplot or bar chart)
854862 # articulation level (mean energy/speaking rate)
@@ -879,7 +887,7 @@ def check_data(self, filelist, word_seg_token=" ", heavy_clip_detction=False):
879887 n_phones = (
880888 len (phone_tokens .split ("/" )) if phone_tokens is not None else None
881889 )
882- audio , _ = torchaudio .load (
890+ audio , sr = torchaudio .load (
883891 str (
884892 self .create_path (
885893 item , "audio" , f"audio-{ self .input_sampling_rate } .wav"
@@ -890,6 +898,15 @@ def check_data(self, filelist, word_seg_token=" ", heavy_clip_detction=False):
890898 len (audio .size ()) == 1 or audio .size (0 ) == 1
891899 ), f"Audio has { audio .size (0 )} channels, but should be mono"
892900 audio = audio .squeeze ()
901+
902+ if heavy_objective_evaluation :
903+ # use objective metrics from https://pytorch.org/audio/main/tutorials/squim_tutorial.html
904+ if sr != 16000 :
905+ audio = torchaudio .functional .resample (audio , sr , 16000 )
906+ stoi_hyp , pesq_hyp , si_sdr_hyp = model (audio )
907+ data_point ["stoi" ] = float (stoi_hyp [0 ])
908+ data_point ["pesq" ] = float (pesq_hyp [0 ])
909+ data_point ["si_sdr" ] = float (si_sdr_hyp [0 ])
893910 if heavy_clip_detction :
894911 _ , total_clipping = detect_clipping (audio )
895912 else :
0 commit comments