7878from nemo .core .config import hydra_runner
7979from nemo .utils import logging
8080
81+
8182# fmt: off
8283
8384
@@ -93,11 +94,10 @@ class EvalBeamSearchNGramConfig:
9394 input_manifest : str = MISSING # The manifest file of the evaluation set
9495 kenlm_model_file : Optional [str ] = None # The path of the KenLM binary model file
9596 preds_output_folder : Optional [str ] = None # The optional folder where the predictions are stored
96- probs_cache_file : Optional [str ] = None # The cache file for storing the logprobs of the model
97+ hyps_cache_file : Optional [str ] = None # The cache file for storing the hypotheses of the model
9798
9899 # Parameters for inference
99- acoustic_batch_size : int = 16 # The batch size to calculate log probabilities
100- beam_batch_size : int = 128 # The batch size to be used for beam search decoding
100+ batch_size : int = 16 # The batch size
101101 device : str = "cuda" # The device to load the model onto to calculate log probabilities
102102 use_amp : bool = False # Whether to use AMP if available to calculate log probabilities
103103
@@ -123,18 +123,31 @@ class EvalBeamSearchNGramConfig:
123123# fmt: on
124124
125125
def apply_text_processing(
    punctuation_capitalization: PunctuationCapitalization, cfg: EvalBeamSearchNGramConfig, text: List[str] | str
) -> List[str] | str:
    """Apply the text-processing steps enabled in ``cfg.text_processing`` to ``text``.

    The enabled steps run in a fixed order: lowercasing, punctuation removal,
    then punctuation separation. Each step is delegated to the method of the
    same name on ``punctuation_capitalization``.

    Args:
        punctuation_capitalization: Object providing ``do_lowercase``,
            ``rm_punctuation`` and ``separate_punctuation``; each is called
            with a list of strings and its result is fed to the next step.
        cfg: Evaluation config; only the ``text_processing.do_lowercase``,
            ``text_processing.rm_punctuation`` and
            ``text_processing.separate_punctuation`` flags are read.
        text: Either a single string or a list of strings.

    Returns:
        The processed text in the same form it was given: a list for list
        input, a single string for string input. When no step is enabled the
        input is returned unchanged.
    """
    is_list = isinstance(text, list)
    # The processing methods are always called with a list, so a lone string
    # is wrapped here and unwrapped again before returning.
    text_arr = text if is_list else [text]
    if cfg.text_processing.do_lowercase:
        text_arr = punctuation_capitalization.do_lowercase(text_arr)
    if cfg.text_processing.rm_punctuation:
        text_arr = punctuation_capitalization.rm_punctuation(text_arr)
    if cfg.text_processing.separate_punctuation:
        text_arr = punctuation_capitalization.separate_punctuation(text_arr)

    return text_arr if is_list else text_arr[0]
139+
140+
126141def beam_search_eval (
142+ audio_filepaths ,
127143 model : nemo_asr .models .ASRModel ,
128144 cfg : EvalBeamSearchNGramConfig ,
129- all_probs : List [torch .Tensor ],
130145 target_transcripts : List [str ],
131146 preds_output_file : str = None ,
132147 lm_path : str = None ,
133148 beam_alpha : float = 1.0 ,
134149 beam_beta : float = 0.0 ,
135150 beam_width : int = 128 ,
136- beam_batch_size : int = 128 ,
137- progress_bar : bool = True ,
138151 punctuation_capitalization : PunctuationCapitalization = None ,
139152):
140153 level = logging .getEffectiveLevel ()
@@ -159,80 +172,47 @@ def beam_search_eval(
159172 # Update model's decoding strategy
160173 if isinstance (model , EncDecHybridRNNTCTCModel ):
161174 model .change_decoding_strategy (model .cfg .decoding , decoder_type = 'ctc' )
162- decoding = model .ctc_decoding
163175 else :
164176 model .change_decoding_strategy (model .cfg .decoding )
165- decoding = model .decoding
166177 logging .setLevel (level )
167178
179+ all_hyps = model .transcribe (audio_filepaths , cfg .batch_size )
180+
168181 wer_dist_first = cer_dist_first = 0
169182 wer_dist_best = cer_dist_best = 0
170183 words_count = 0
171184 chars_count = 0
172- sample_idx = 0
173185 if preds_output_file :
174186 out_file = open (preds_output_file , 'w' , encoding = 'utf_8' , newline = '\n ' )
175187
176- if progress_bar :
177- it = tqdm (
178- range (int (np .ceil (len (all_probs ) / beam_batch_size ))),
179- desc = f"Beam search decoding with width={ beam_width } , alpha={ beam_alpha } , beta={ beam_beta } " ,
180- ncols = 120 ,
181- )
182- else :
183- it = range (int (np .ceil (len (all_probs ) / beam_batch_size )))
184- for batch_idx in it :
185- # disabling type checking
186- probs_batch = all_probs [batch_idx * beam_batch_size : (batch_idx + 1 ) * beam_batch_size ]
187- probs_lens = torch .tensor ([prob .shape [0 ] for prob in probs_batch ])
188- with torch .no_grad ():
189- packed_batch = torch .zeros (len (probs_batch ), max (probs_lens ), probs_batch [0 ].shape [- 1 ], device = 'cpu' )
190-
191- for prob_index in range (len (probs_batch )):
192- packed_batch [prob_index , : probs_lens [prob_index ], :] = torch .tensor (
193- probs_batch [prob_index ], device = packed_batch .device , dtype = packed_batch .dtype
194- )
188+ for batch_idx , nbest_hyp in enumerate (all_hyps ):
189+ target = target_transcripts [batch_idx ]
190+ target_split_w = target .split ()
191+ target_split_c = list (target )
192+ words_count += len (target_split_w )
193+ chars_count += len (target_split_c )
194+ wer_dist_min = cer_dist_min = float ("inf" )
195+ for candidate_idx , candidate in enumerate (nbest_hyp ):
196+ pred_text = apply_text_processing (punctuation_capitalization , cfg , candidate .text )
195197
196- beams_batch = decoding .ctc_decoder_predictions_tensor (
197- packed_batch ,
198- decoder_lengths = probs_lens ,
199- return_hypotheses = True ,
200- )
198+ pred_split_w = pred_text .split ()
199+ wer_dist = editdistance .eval (target_split_w , pred_split_w )
200+ pred_split_c = list (pred_text )
201+ cer_dist = editdistance .eval (target_split_c , pred_split_c )
201202
202- for beams_idx , beams in enumerate (beams_batch ):
203- target = target_transcripts [sample_idx + beams_idx ]
204- target_split_w = target .split ()
205- target_split_c = list (target )
206- words_count += len (target_split_w )
207- chars_count += len (target_split_c )
208- wer_dist_min = cer_dist_min = 10000
209- for candidate_idx , candidate in enumerate (beams ): # type: (int, ctc_beam_decoding.rnnt_utils.Hypothesis)
210- pred_text = candidate .text
211- if cfg .text_processing .do_lowercase :
212- pred_text = punctuation_capitalization .do_lowercase ([pred_text ])[0 ]
213- if cfg .text_processing .rm_punctuation :
214- pred_text = punctuation_capitalization .rm_punctuation ([pred_text ])[0 ]
215- if cfg .text_processing .separate_punctuation :
216- pred_text = punctuation_capitalization .separate_punctuation ([pred_text ])[0 ]
217- pred_split_w = pred_text .split ()
218- wer_dist = editdistance .eval (target_split_w , pred_split_w )
219- pred_split_c = list (pred_text )
220- cer_dist = editdistance .eval (target_split_c , pred_split_c )
221-
222- wer_dist_min = min (wer_dist_min , wer_dist )
223- cer_dist_min = min (cer_dist_min , cer_dist )
224-
225- if candidate_idx == 0 :
226- # first candidate
227- wer_dist_first += wer_dist
228- cer_dist_first += cer_dist
229-
230- score = candidate .score
231- if preds_output_file :
232- out_file .write ('{}\t {}\n ' .format (pred_text , score ))
233- wer_dist_best += wer_dist_min
234- cer_dist_best += cer_dist_min
235- sample_idx += len (probs_batch )
203+ wer_dist_min = min (wer_dist_min , wer_dist )
204+ cer_dist_min = min (cer_dist_min , cer_dist )
205+
206+ if candidate_idx == 0 :
207+ # first candidate
208+ wer_dist_first += wer_dist
209+ cer_dist_first += cer_dist
210+
211+ score = candidate .score
212+ if preds_output_file :
213+ out_file .write ('{}\t {}\n ' .format (pred_text , score ))
214+ wer_dist_best += wer_dist_min
215+ cer_dist_best += cer_dist_min
236216
237217 if preds_output_file :
238218 out_file .close ()
@@ -255,6 +235,7 @@ def beam_search_eval(
255235 wer_dist_best / words_count , cer_dist_best / chars_count
256236 )
257237 )
238+
258239 logging .info (f"=================================================================================" )
259240
260241 return wer_dist_first / words_count , cer_dist_first / chars_count
@@ -294,57 +275,39 @@ def main(cfg: EvalBeamSearchNGramConfig):
294275 audio_file_paths .append (str (audio_file .absolute ()))
295276
296277 punctuation_capitalization = PunctuationCapitalization (cfg .text_processing .punctuation_marks )
297- if cfg .text_processing .do_lowercase :
298- target_transcripts = punctuation_capitalization .do_lowercase (target_transcripts )
299- if cfg .text_processing .rm_punctuation :
300- target_transcripts = punctuation_capitalization .rm_punctuation (target_transcripts )
301- if cfg .text_processing .separate_punctuation :
302- target_transcripts = punctuation_capitalization .separate_punctuation (target_transcripts )
278+ target_transcripts = apply_text_processing (punctuation_capitalization , cfg , target_transcripts )
303279
304- if cfg .probs_cache_file and os .path .exists (cfg .probs_cache_file ):
305- logging .info (f"Found a pickle file of probabilities at '{ cfg .probs_cache_file } '." )
306- logging .info (f"Loading the cached pickle file of probabilities from '{ cfg .probs_cache_file } ' ..." )
307- with open (cfg .probs_cache_file , 'rb' ) as probs_file :
308- all_probs = pickle .load (probs_file )
280+ if cfg .hyps_cache_file and os .path .exists (cfg .hyps_cache_file ):
281+ logging .info (f"Found a pickle file of hypotheses at '{ cfg .hyps_cache_file } '." )
282+ logging .info (f"Loading the cached pickle file of hypotheses from '{ cfg .hyps_cache_file } ' ..." )
283+ with open (cfg .hyps_cache_file , 'rb' ) as probs_file :
284+ all_hyps = pickle .load (probs_file )
309285
310- if len (all_probs ) != len (audio_file_paths ):
286+ if len (all_hyps ) != len (audio_file_paths ):
311287 raise ValueError (
312- f"The number of samples in the probabilities file '{ cfg .probs_cache_file } ' does not "
313- f"match the manifest file. You may need to delete the probabilities cached file."
288+ f"The number of samples in the hypotheses file '{ cfg .hyps_cache_file } ' does not "
289+ f"match the manifest file. You may need to delete the hypotheses cached file."
314290 )
315291 else :
316292
317293 with torch .amp .autocast (asr_model .device .type , enabled = cfg .use_amp ):
318294 with torch .no_grad ():
319295 if isinstance (asr_model , EncDecHybridRNNTCTCModel ):
320296 asr_model .cur_decoder = 'ctc'
321- all_logits = asr_model .transcribe (audio_file_paths , batch_size = cfg .acoustic_batch_size , logprobs = True )
297+ all_hyps = asr_model .transcribe (audio_file_paths , batch_size = cfg .batch_size )
322298
323- all_probs = all_logits
324- if cfg .probs_cache_file :
325- os .makedirs (os .path .split (cfg .probs_cache_file )[0 ], exist_ok = True )
326- logging .info (f"Writing pickle files of probabilities at '{ cfg .probs_cache_file } '..." )
327- with open (cfg .probs_cache_file , 'wb' ) as f_dump :
328- pickle .dump (all_probs , f_dump )
299+ if cfg .hyps_cache_file :
300+ os .makedirs (os .path .split (cfg .hyps_cache_file )[0 ], exist_ok = True )
301+ logging .info (f"Writing pickle files of hypotheses at '{ cfg .hyps_cache_file } '..." )
302+ with open (cfg .hyps_cache_file , 'wb' ) as f_dump :
303+ pickle .dump (all_hyps , f_dump )
329304
330305 wer_dist_greedy = 0
331306 cer_dist_greedy = 0
332307 words_count = 0
333308 chars_count = 0
334- for batch_idx , probs in enumerate (all_probs ):
335- preds = np .argmax (probs , axis = 1 )
336- preds_tensor = torch .tensor (preds , device = 'cpu' ).unsqueeze (0 )
337- if isinstance (asr_model , EncDecHybridRNNTCTCModel ):
338- pred_text = asr_model .ctc_decoding .ctc_decoder_predictions_tensor (preds_tensor )[0 ]
339- else :
340- pred_text = asr_model ._wer .decoding .ctc_decoder_predictions_tensor (preds_tensor )[0 ]
341-
342- if cfg .text_processing .do_lowercase :
343- pred_text = punctuation_capitalization .do_lowercase ([pred_text ])[0 ]
344- if cfg .text_processing .rm_punctuation :
345- pred_text = punctuation_capitalization .rm_punctuation ([pred_text ])[0 ]
346- if cfg .text_processing .separate_punctuation :
347- pred_text = punctuation_capitalization .separate_punctuation ([pred_text ])[0 ]
309+ for batch_idx , hyp in enumerate (all_hyps ):
310+ pred_text = apply_text_processing (punctuation_capitalization , cfg , hyp .text )
348311
349312 pred_split_w = pred_text .split ()
350313 target_split_w = target_transcripts [batch_idx ].split ()
@@ -381,7 +344,7 @@ def main(cfg: EvalBeamSearchNGramConfig):
381344 best_wer_beam_size , best_cer_beam_size = None , None
382345 best_wer_alpha , best_cer_alpha = None , None
383346 best_wer_beta , best_cer_beta = None , None
384- best_wer , best_cer = 1e6 , 1e6
347+ best_wer , best_cer = float ( "inf" ), float ( "inf" )
385348
386349 logging .info (f"==============================Starting the beam search decoding===============================" )
387350 logging .info (f"Grid search size: { len (hp_grid )} " )
@@ -400,31 +363,33 @@ def main(cfg: EvalBeamSearchNGramConfig):
400363 preds_output_file = None
401364
402365 candidate_wer , candidate_cer = beam_search_eval (
366+ audio_file_paths ,
403367 asr_model ,
404368 cfg ,
405- all_probs = all_probs ,
406369 target_transcripts = target_transcripts ,
407370 preds_output_file = preds_output_file ,
408371 lm_path = lm_path ,
409372 beam_width = hp ["beam_width" ],
410373 beam_alpha = hp ["beam_alpha" ],
411374 beam_beta = hp ["beam_beta" ],
412- beam_batch_size = cfg .beam_batch_size ,
413- progress_bar = True ,
414375 punctuation_capitalization = punctuation_capitalization ,
415376 )
416377
417378 if candidate_cer < best_cer :
418- best_cer_beam_size = hp ["beam_width" ]
419- best_cer_alpha = hp ["beam_alpha" ]
420- best_cer_beta = hp ["beam_beta" ]
421- best_cer = candidate_cer
379+ best_cer_beam_size , best_cer_alpha , best_cer_beta , best_cer = (
380+ hp ["beam_width" ],
381+ hp ["beam_alpha" ],
382+ hp ["beam_beta" ],
383+ candidate_cer ,
384+ )
422385
423386 if candidate_wer < best_wer :
424- best_wer_beam_size = hp ["beam_width" ]
425- best_wer_alpha = hp ["beam_alpha" ]
426- best_wer_beta = hp ["beam_beta" ]
427- best_wer = candidate_wer
387+ best_wer_beam_size , best_wer_alpha , best_wer_beta , best_wer = (
388+ hp ["beam_width" ],
389+ hp ["beam_alpha" ],
390+ hp ["beam_beta" ],
391+ candidate_wer ,
392+ )
428393
429394 logging .info (
430395 f'Best WER Candidate = { best_wer :.2%} :: Beam size = { best_wer_beam_size } , '
0 commit comments