7878from nemo .core .config import hydra_runner
7979from nemo .utils import logging
8080
81-
8281# fmt: off
8382
8483
@@ -94,10 +93,11 @@ class EvalBeamSearchNGramConfig:
9493 input_manifest : str = MISSING # The manifest file of the evaluation set
9594 kenlm_model_file : Optional [str ] = None # The path of the KenLM binary model file
9695 preds_output_folder : Optional [str ] = None # The optional folder where the predictions are stored
97- hyps_cache_file : Optional [str ] = None # The cache file for storing the logprobs of the model
96+ probs_cache_file : Optional [str ] = None # The cache file for storing the logprobs of the model
9897
9998 # Parameters for inference
100- batch_size : int = 16 # The batch size
99+ acoustic_batch_size : int = 16 # The batch size to calculate log probabilities
100+ beam_batch_size : int = 128 # The batch size to be used for beam search decoding
101101 device : str = "cuda" # The device to load the model onto to calculate log probabilities
102102 use_amp : bool = False # Whether to use AMP if available to calculate log probabilities
103103
@@ -123,31 +123,18 @@ class EvalBeamSearchNGramConfig:
123123# fmt: on
124124
125125
126- def apply_text_processing (
127- punctuation_capitalization : PunctuationCapitalization , cfg : EvalBeamSearchNGramConfig , text : List [str ] | str
128- ) -> List [str ] | str :
129- is_list = isinstance (text , list )
130- text_arr = text if is_list else [text ]
131- if cfg .text_processing .do_lowercase :
132- text_arr = punctuation_capitalization .do_lowercase (text_arr )
133- if cfg .text_processing .rm_punctuation :
134- text_arr = punctuation_capitalization .rm_punctuation (text_arr )
135- if cfg .text_processing .separate_punctuation :
136- text_arr = punctuation_capitalization .separate_punctuation (text_arr )
137-
138- return text_arr if is_list else text_arr [0 ]
139-
140-
141126def beam_search_eval (
142- audio_filepaths ,
143127 model : nemo_asr .models .ASRModel ,
144128 cfg : EvalBeamSearchNGramConfig ,
129+ all_probs : List [torch .Tensor ],
145130 target_transcripts : List [str ],
146131 preds_output_file : str = None ,
147132 lm_path : str = None ,
148133 beam_alpha : float = 1.0 ,
149134 beam_beta : float = 0.0 ,
150135 beam_width : int = 128 ,
136+ beam_batch_size : int = 128 ,
137+ progress_bar : bool = True ,
151138 punctuation_capitalization : PunctuationCapitalization = None ,
152139):
153140 level = logging .getEffectiveLevel ()
@@ -172,47 +159,80 @@ def beam_search_eval(
172159 # Update model's decoding strategy
173160 if isinstance (model , EncDecHybridRNNTCTCModel ):
174161 model .change_decoding_strategy (model .cfg .decoding , decoder_type = 'ctc' )
162+ decoding = model .ctc_decoding
175163 else :
176164 model .change_decoding_strategy (model .cfg .decoding )
165+ decoding = model .decoding
177166 logging .setLevel (level )
178167
179- all_hyps = model .transcribe (audio_filepaths , cfg .batch_size )
180-
181168 wer_dist_first = cer_dist_first = 0
182169 wer_dist_best = cer_dist_best = 0
183170 words_count = 0
184171 chars_count = 0
172+ sample_idx = 0
185173 if preds_output_file :
186174 out_file = open (preds_output_file , 'w' , encoding = 'utf_8' , newline = '\n ' )
187175
188- for batch_idx , nbest_hyp in enumerate (all_hyps ):
189- target = target_transcripts [batch_idx ]
190- target_split_w = target .split ()
191- target_split_c = list (target )
192- words_count += len (target_split_w )
193- chars_count += len (target_split_c )
194- wer_dist_min = cer_dist_min = float ("inf" )
195- for candidate_idx , candidate in enumerate (nbest_hyp ):
196- pred_text = apply_text_processing (punctuation_capitalization , cfg , candidate .text )
197-
198- pred_split_w = pred_text .split ()
199- wer_dist = editdistance .eval (target_split_w , pred_split_w )
200- pred_split_c = list (pred_text )
201- cer_dist = editdistance .eval (target_split_c , pred_split_c )
202-
203- wer_dist_min = min (wer_dist_min , wer_dist )
204- cer_dist_min = min (cer_dist_min , cer_dist )
176+ if progress_bar :
177+ it = tqdm (
178+ range (int (np .ceil (len (all_probs ) / beam_batch_size ))),
179+ desc = f"Beam search decoding with width={ beam_width } , alpha={ beam_alpha } , beta={ beam_beta } " ,
180+ ncols = 120 ,
181+ )
182+ else :
183+ it = range (int (np .ceil (len (all_probs ) / beam_batch_size )))
184+ for batch_idx in it :
185+ # disabling type checking
186+ probs_batch = all_probs [batch_idx * beam_batch_size : (batch_idx + 1 ) * beam_batch_size ]
187+ probs_lens = torch .tensor ([prob .shape [0 ] for prob in probs_batch ])
188+ with torch .no_grad ():
189+ packed_batch = torch .zeros (len (probs_batch ), max (probs_lens ), probs_batch [0 ].shape [- 1 ], device = 'cpu' )
190+
191+ for prob_index in range (len (probs_batch )):
192+ packed_batch [prob_index , : probs_lens [prob_index ], :] = torch .tensor (
193+ probs_batch [prob_index ], device = packed_batch .device , dtype = packed_batch .dtype
194+ )
205195
206- if candidate_idx == 0 :
207- # first candidate
208- wer_dist_first += wer_dist
209- cer_dist_first += cer_dist
196+ beams_batch = decoding .ctc_decoder_predictions_tensor (
197+ packed_batch ,
198+ decoder_lengths = probs_lens ,
199+ return_hypotheses = True ,
200+ )
210201
211- score = candidate .score
212- if preds_output_file :
213- out_file .write ('{}\t {}\n ' .format (pred_text , score ))
214- wer_dist_best += wer_dist_min
215- cer_dist_best += cer_dist_min
202+ for beams_idx , beams in enumerate (beams_batch ):
203+ target = target_transcripts [sample_idx + beams_idx ]
204+ target_split_w = target .split ()
205+ target_split_c = list (target )
206+ words_count += len (target_split_w )
207+ chars_count += len (target_split_c )
208+ wer_dist_min = cer_dist_min = 10000
209+ for candidate_idx , candidate in enumerate (beams ): # type: (int, ctc_beam_decoding.rnnt_utils.Hypothesis)
210+ pred_text = candidate .text
211+ if cfg .text_processing .do_lowercase :
212+ pred_text = punctuation_capitalization .do_lowercase ([pred_text ])[0 ]
213+ if cfg .text_processing .rm_punctuation :
214+ pred_text = punctuation_capitalization .rm_punctuation ([pred_text ])[0 ]
215+ if cfg .text_processing .separate_punctuation :
216+ pred_text = punctuation_capitalization .separate_punctuation ([pred_text ])[0 ]
217+ pred_split_w = pred_text .split ()
218+ wer_dist = editdistance .eval (target_split_w , pred_split_w )
219+ pred_split_c = list (pred_text )
220+ cer_dist = editdistance .eval (target_split_c , pred_split_c )
221+
222+ wer_dist_min = min (wer_dist_min , wer_dist )
223+ cer_dist_min = min (cer_dist_min , cer_dist )
224+
225+ if candidate_idx == 0 :
226+ # first candidate
227+ wer_dist_first += wer_dist
228+ cer_dist_first += cer_dist
229+
230+ score = candidate .score
231+ if preds_output_file :
232+ out_file .write ('{}\t {}\n ' .format (pred_text , score ))
233+ wer_dist_best += wer_dist_min
234+ cer_dist_best += cer_dist_min
235+ sample_idx += len (probs_batch )
216236
217237 if preds_output_file :
218238 out_file .close ()
@@ -235,7 +255,6 @@ def beam_search_eval(
235255 wer_dist_best / words_count , cer_dist_best / chars_count
236256 )
237257 )
238-
239258 logging .info (f"=================================================================================" )
240259
241260 return wer_dist_first / words_count , cer_dist_first / chars_count
@@ -275,39 +294,57 @@ def main(cfg: EvalBeamSearchNGramConfig):
275294 audio_file_paths .append (str (audio_file .absolute ()))
276295
277296 punctuation_capitalization = PunctuationCapitalization (cfg .text_processing .punctuation_marks )
278- target_transcripts = apply_text_processing (punctuation_capitalization , cfg , target_transcripts )
297+ if cfg .text_processing .do_lowercase :
298+ target_transcripts = punctuation_capitalization .do_lowercase (target_transcripts )
299+ if cfg .text_processing .rm_punctuation :
300+ target_transcripts = punctuation_capitalization .rm_punctuation (target_transcripts )
301+ if cfg .text_processing .separate_punctuation :
302+ target_transcripts = punctuation_capitalization .separate_punctuation (target_transcripts )
279303
280- if cfg .hyps_cache_file and os .path .exists (cfg .hyps_cache_file ):
281- logging .info (f"Found a pickle file of hypotheses at '{ cfg .hyps_cache_file } '." )
282- logging .info (f"Loading the cached pickle file of hypotheses from '{ cfg .hyps_cache_file } ' ..." )
283- with open (cfg .hyps_cache_file , 'rb' ) as probs_file :
284- all_hyps = pickle .load (probs_file )
304+ if cfg .probs_cache_file and os .path .exists (cfg .probs_cache_file ):
305+ logging .info (f"Found a pickle file of probabilities at '{ cfg .probs_cache_file } '." )
306+ logging .info (f"Loading the cached pickle file of probabilities from '{ cfg .probs_cache_file } ' ..." )
307+ with open (cfg .probs_cache_file , 'rb' ) as probs_file :
308+ all_probs = pickle .load (probs_file )
285309
286- if len (all_hyps ) != len (audio_file_paths ):
310+ if len (all_probs ) != len (audio_file_paths ):
287311 raise ValueError (
288- f"The number of samples in the hypotheses file '{ cfg .hyps_cache_file } ' does not "
289- f"match the manifest file. You may need to delete the hypotheses cached file."
312+ f"The number of samples in the probabilities file '{ cfg .probs_cache_file } ' does not "
313+ f"match the manifest file. You may need to delete the probabilities cached file."
290314 )
291315 else :
292316
293317 with torch .amp .autocast (asr_model .device .type , enabled = cfg .use_amp ):
294318 with torch .no_grad ():
295319 if isinstance (asr_model , EncDecHybridRNNTCTCModel ):
296320 asr_model .cur_decoder = 'ctc'
297- all_hyps = asr_model .transcribe (audio_file_paths , batch_size = cfg .batch_size )
321+ all_logits = asr_model .transcribe (audio_file_paths , batch_size = cfg .acoustic_batch_size , logprobs = True )
298322
299- if cfg .hyps_cache_file :
300- os .makedirs (os .path .split (cfg .hyps_cache_file )[0 ], exist_ok = True )
301- logging .info (f"Writing pickle files of hypotheses at '{ cfg .hyps_cache_file } '..." )
302- with open (cfg .hyps_cache_file , 'wb' ) as f_dump :
303- pickle .dump (all_hyps , f_dump )
323+ all_probs = all_logits
324+ if cfg .probs_cache_file :
325+ os .makedirs (os .path .split (cfg .probs_cache_file )[0 ], exist_ok = True )
326+ logging .info (f"Writing pickle files of probabilities at '{ cfg .probs_cache_file } '..." )
327+ with open (cfg .probs_cache_file , 'wb' ) as f_dump :
328+ pickle .dump (all_probs , f_dump )
304329
305330 wer_dist_greedy = 0
306331 cer_dist_greedy = 0
307332 words_count = 0
308333 chars_count = 0
309- for batch_idx , hyp in enumerate (all_hyps ):
310- pred_text = apply_text_processing (punctuation_capitalization , cfg , hyp .text )
334+ for batch_idx , probs in enumerate (all_probs ):
335+ preds = np .argmax (probs , axis = 1 )
336+ preds_tensor = torch .tensor (preds , device = 'cpu' ).unsqueeze (0 )
337+ if isinstance (asr_model , EncDecHybridRNNTCTCModel ):
338+ pred_text = asr_model .ctc_decoding .ctc_decoder_predictions_tensor (preds_tensor )[0 ]
339+ else :
340+ pred_text = asr_model ._wer .decoding .ctc_decoder_predictions_tensor (preds_tensor )[0 ]
341+
342+ if cfg .text_processing .do_lowercase :
343+ pred_text = punctuation_capitalization .do_lowercase ([pred_text ])[0 ]
344+ if cfg .text_processing .rm_punctuation :
345+ pred_text = punctuation_capitalization .rm_punctuation ([pred_text ])[0 ]
346+ if cfg .text_processing .separate_punctuation :
347+ pred_text = punctuation_capitalization .separate_punctuation ([pred_text ])[0 ]
311348
312349 pred_split_w = pred_text .split ()
313350 target_split_w = target_transcripts [batch_idx ].split ()
@@ -344,7 +381,7 @@ def main(cfg: EvalBeamSearchNGramConfig):
344381 best_wer_beam_size , best_cer_beam_size = None , None
345382 best_wer_alpha , best_cer_alpha = None , None
346383 best_wer_beta , best_cer_beta = None , None
347- best_wer , best_cer = float ( "inf" ), float ( "inf" )
384+ best_wer , best_cer = 1e6 , 1e6
348385
349386 logging .info (f"==============================Starting the beam search decoding===============================" )
350387 logging .info (f"Grid search size: { len (hp_grid )} " )
@@ -363,33 +400,31 @@ def main(cfg: EvalBeamSearchNGramConfig):
363400 preds_output_file = None
364401
365402 candidate_wer , candidate_cer = beam_search_eval (
366- audio_file_paths ,
367403 asr_model ,
368404 cfg ,
405+ all_probs = all_probs ,
369406 target_transcripts = target_transcripts ,
370407 preds_output_file = preds_output_file ,
371408 lm_path = lm_path ,
372409 beam_width = hp ["beam_width" ],
373410 beam_alpha = hp ["beam_alpha" ],
374411 beam_beta = hp ["beam_beta" ],
412+ beam_batch_size = cfg .beam_batch_size ,
413+ progress_bar = True ,
375414 punctuation_capitalization = punctuation_capitalization ,
376415 )
377416
378417 if candidate_cer < best_cer :
379- best_cer_beam_size , best_cer_alpha , best_cer_beta , best_cer = (
380- hp ["beam_width" ],
381- hp ["beam_alpha" ],
382- hp ["beam_beta" ],
383- candidate_cer ,
384- )
418+ best_cer_beam_size = hp ["beam_width" ]
419+ best_cer_alpha = hp ["beam_alpha" ]
420+ best_cer_beta = hp ["beam_beta" ]
421+ best_cer = candidate_cer
385422
386423 if candidate_wer < best_wer :
387- best_wer_beam_size , best_wer_alpha , best_wer_beta , best_wer = (
388- hp ["beam_width" ],
389- hp ["beam_alpha" ],
390- hp ["beam_beta" ],
391- candidate_wer ,
392- )
424+ best_wer_beam_size = hp ["beam_width" ]
425+ best_wer_alpha = hp ["beam_alpha" ]
426+ best_wer_beta = hp ["beam_beta" ]
427+ best_wer = candidate_wer
393428
394429 logging .info (
395430 f'Best WER Candidate = { best_wer :.2%} :: Beam size = { best_wer_beam_size } , '
0 commit comments