Skip to content

Commit 02f4e12

Browse files
Merge branch 'lgrigoryan/fix-eval_beamsearch_ngram_ctc' of https://github.com/NVIDIA/NeMo into lgrigoryan/fix-eval_beamsearch_ngram_ctc
2 parents a7eaeea + 9f2b314 commit 02f4e12

File tree

1 file changed

+114
-79
lines changed

1 file changed

+114
-79
lines changed

scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py

Lines changed: 114 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@
7878
from nemo.core.config import hydra_runner
7979
from nemo.utils import logging
8080

81-
8281
# fmt: off
8382

8483

@@ -94,10 +93,11 @@ class EvalBeamSearchNGramConfig:
9493
input_manifest: str = MISSING # The manifest file of the evaluation set
9594
kenlm_model_file: Optional[str] = None # The path of the KenLM binary model file
9695
preds_output_folder: Optional[str] = None # The optional folder where the predictions are stored
97-
hyps_cache_file: Optional[str] = None # The cache file for storing the logprobs of the model
96+
probs_cache_file: Optional[str] = None # The cache file for storing the logprobs of the model
9897

9998
# Parameters for inference
100-
batch_size: int = 16 # The batch size
99+
acoustic_batch_size: int = 16 # The batch size to calculate log probabilities
100+
beam_batch_size: int = 128 # The batch size to be used for beam search decoding
101101
device: str = "cuda" # The device to load the model onto to calculate log probabilities
102102
use_amp: bool = False # Whether to use AMP if available to calculate log probabilities
103103

@@ -123,31 +123,18 @@ class EvalBeamSearchNGramConfig:
123123
# fmt: on
124124

125125

126-
def apply_text_processing(
127-
punctuation_capitalization: PunctuationCapitalization, cfg: EvalBeamSearchNGramConfig, text: List[str] | str
128-
) -> List[str] | str:
129-
is_list = isinstance(text, list)
130-
text_arr = text if is_list else [text]
131-
if cfg.text_processing.do_lowercase:
132-
text_arr = punctuation_capitalization.do_lowercase(text_arr)
133-
if cfg.text_processing.rm_punctuation:
134-
text_arr = punctuation_capitalization.rm_punctuation(text_arr)
135-
if cfg.text_processing.separate_punctuation:
136-
text_arr = punctuation_capitalization.separate_punctuation(text_arr)
137-
138-
return text_arr if is_list else text_arr[0]
139-
140-
141126
def beam_search_eval(
142-
audio_filepaths,
143127
model: nemo_asr.models.ASRModel,
144128
cfg: EvalBeamSearchNGramConfig,
129+
all_probs: List[torch.Tensor],
145130
target_transcripts: List[str],
146131
preds_output_file: str = None,
147132
lm_path: str = None,
148133
beam_alpha: float = 1.0,
149134
beam_beta: float = 0.0,
150135
beam_width: int = 128,
136+
beam_batch_size: int = 128,
137+
progress_bar: bool = True,
151138
punctuation_capitalization: PunctuationCapitalization = None,
152139
):
153140
level = logging.getEffectiveLevel()
@@ -172,47 +159,80 @@ def beam_search_eval(
172159
# Update model's decoding strategy
173160
if isinstance(model, EncDecHybridRNNTCTCModel):
174161
model.change_decoding_strategy(model.cfg.decoding, decoder_type='ctc')
162+
decoding = model.ctc_decoding
175163
else:
176164
model.change_decoding_strategy(model.cfg.decoding)
165+
decoding = model.decoding
177166
logging.setLevel(level)
178167

179-
all_hyps = model.transcribe(audio_filepaths, cfg.batch_size)
180-
181168
wer_dist_first = cer_dist_first = 0
182169
wer_dist_best = cer_dist_best = 0
183170
words_count = 0
184171
chars_count = 0
172+
sample_idx = 0
185173
if preds_output_file:
186174
out_file = open(preds_output_file, 'w', encoding='utf_8', newline='\n')
187175

188-
for batch_idx, nbest_hyp in enumerate(all_hyps):
189-
target = target_transcripts[batch_idx]
190-
target_split_w = target.split()
191-
target_split_c = list(target)
192-
words_count += len(target_split_w)
193-
chars_count += len(target_split_c)
194-
wer_dist_min = cer_dist_min = float("inf")
195-
for candidate_idx, candidate in enumerate(nbest_hyp):
196-
pred_text = apply_text_processing(punctuation_capitalization, cfg, candidate.text)
197-
198-
pred_split_w = pred_text.split()
199-
wer_dist = editdistance.eval(target_split_w, pred_split_w)
200-
pred_split_c = list(pred_text)
201-
cer_dist = editdistance.eval(target_split_c, pred_split_c)
202-
203-
wer_dist_min = min(wer_dist_min, wer_dist)
204-
cer_dist_min = min(cer_dist_min, cer_dist)
176+
if progress_bar:
177+
it = tqdm(
178+
range(int(np.ceil(len(all_probs) / beam_batch_size))),
179+
desc=f"Beam search decoding with width={beam_width}, alpha={beam_alpha}, beta={beam_beta}",
180+
ncols=120,
181+
)
182+
else:
183+
it = range(int(np.ceil(len(all_probs) / beam_batch_size)))
184+
for batch_idx in it:
185+
# disabling type checking
186+
probs_batch = all_probs[batch_idx * beam_batch_size : (batch_idx + 1) * beam_batch_size]
187+
probs_lens = torch.tensor([prob.shape[0] for prob in probs_batch])
188+
with torch.no_grad():
189+
packed_batch = torch.zeros(len(probs_batch), max(probs_lens), probs_batch[0].shape[-1], device='cpu')
190+
191+
for prob_index in range(len(probs_batch)):
192+
packed_batch[prob_index, : probs_lens[prob_index], :] = torch.tensor(
193+
probs_batch[prob_index], device=packed_batch.device, dtype=packed_batch.dtype
194+
)
205195

206-
if candidate_idx == 0:
207-
# first candidate
208-
wer_dist_first += wer_dist
209-
cer_dist_first += cer_dist
196+
beams_batch = decoding.ctc_decoder_predictions_tensor(
197+
packed_batch,
198+
decoder_lengths=probs_lens,
199+
return_hypotheses=True,
200+
)
210201

211-
score = candidate.score
212-
if preds_output_file:
213-
out_file.write('{}\t{}\n'.format(pred_text, score))
214-
wer_dist_best += wer_dist_min
215-
cer_dist_best += cer_dist_min
202+
for beams_idx, beams in enumerate(beams_batch):
203+
target = target_transcripts[sample_idx + beams_idx]
204+
target_split_w = target.split()
205+
target_split_c = list(target)
206+
words_count += len(target_split_w)
207+
chars_count += len(target_split_c)
208+
wer_dist_min = cer_dist_min = 10000
209+
for candidate_idx, candidate in enumerate(beams): # type: (int, ctc_beam_decoding.rnnt_utils.Hypothesis)
210+
pred_text = candidate.text
211+
if cfg.text_processing.do_lowercase:
212+
pred_text = punctuation_capitalization.do_lowercase([pred_text])[0]
213+
if cfg.text_processing.rm_punctuation:
214+
pred_text = punctuation_capitalization.rm_punctuation([pred_text])[0]
215+
if cfg.text_processing.separate_punctuation:
216+
pred_text = punctuation_capitalization.separate_punctuation([pred_text])[0]
217+
pred_split_w = pred_text.split()
218+
wer_dist = editdistance.eval(target_split_w, pred_split_w)
219+
pred_split_c = list(pred_text)
220+
cer_dist = editdistance.eval(target_split_c, pred_split_c)
221+
222+
wer_dist_min = min(wer_dist_min, wer_dist)
223+
cer_dist_min = min(cer_dist_min, cer_dist)
224+
225+
if candidate_idx == 0:
226+
# first candidate
227+
wer_dist_first += wer_dist
228+
cer_dist_first += cer_dist
229+
230+
score = candidate.score
231+
if preds_output_file:
232+
out_file.write('{}\t{}\n'.format(pred_text, score))
233+
wer_dist_best += wer_dist_min
234+
cer_dist_best += cer_dist_min
235+
sample_idx += len(probs_batch)
216236

217237
if preds_output_file:
218238
out_file.close()
@@ -235,7 +255,6 @@ def beam_search_eval(
235255
wer_dist_best / words_count, cer_dist_best / chars_count
236256
)
237257
)
238-
239258
logging.info(f"=================================================================================")
240259

241260
return wer_dist_first / words_count, cer_dist_first / chars_count
@@ -275,39 +294,57 @@ def main(cfg: EvalBeamSearchNGramConfig):
275294
audio_file_paths.append(str(audio_file.absolute()))
276295

277296
punctuation_capitalization = PunctuationCapitalization(cfg.text_processing.punctuation_marks)
278-
target_transcripts = apply_text_processing(punctuation_capitalization, cfg, target_transcripts)
297+
if cfg.text_processing.do_lowercase:
298+
target_transcripts = punctuation_capitalization.do_lowercase(target_transcripts)
299+
if cfg.text_processing.rm_punctuation:
300+
target_transcripts = punctuation_capitalization.rm_punctuation(target_transcripts)
301+
if cfg.text_processing.separate_punctuation:
302+
target_transcripts = punctuation_capitalization.separate_punctuation(target_transcripts)
279303

280-
if cfg.hyps_cache_file and os.path.exists(cfg.hyps_cache_file):
281-
logging.info(f"Found a pickle file of hypotheses at '{cfg.hyps_cache_file}'.")
282-
logging.info(f"Loading the cached pickle file of hypotheses from '{cfg.hyps_cache_file}' ...")
283-
with open(cfg.hyps_cache_file, 'rb') as probs_file:
284-
all_hyps = pickle.load(probs_file)
304+
if cfg.probs_cache_file and os.path.exists(cfg.probs_cache_file):
305+
logging.info(f"Found a pickle file of probabilities at '{cfg.probs_cache_file}'.")
306+
logging.info(f"Loading the cached pickle file of probabilities from '{cfg.probs_cache_file}' ...")
307+
with open(cfg.probs_cache_file, 'rb') as probs_file:
308+
all_probs = pickle.load(probs_file)
285309

286-
if len(all_hyps) != len(audio_file_paths):
310+
if len(all_probs) != len(audio_file_paths):
287311
raise ValueError(
288-
f"The number of samples in the hypotheses file '{cfg.hyps_cache_file}' does not "
289-
f"match the manifest file. You may need to delete the hypotheses cached file."
312+
f"The number of samples in the probabilities file '{cfg.probs_cache_file}' does not "
313+
f"match the manifest file. You may need to delete the probabilities cached file."
290314
)
291315
else:
292316

293317
with torch.amp.autocast(asr_model.device.type, enabled=cfg.use_amp):
294318
with torch.no_grad():
295319
if isinstance(asr_model, EncDecHybridRNNTCTCModel):
296320
asr_model.cur_decoder = 'ctc'
297-
all_hyps = asr_model.transcribe(audio_file_paths, batch_size=cfg.batch_size)
321+
all_logits = asr_model.transcribe(audio_file_paths, batch_size=cfg.acoustic_batch_size, logprobs=True)
298322

299-
if cfg.hyps_cache_file:
300-
os.makedirs(os.path.split(cfg.hyps_cache_file)[0], exist_ok=True)
301-
logging.info(f"Writing pickle files of hypotheses at '{cfg.hyps_cache_file}'...")
302-
with open(cfg.hyps_cache_file, 'wb') as f_dump:
303-
pickle.dump(all_hyps, f_dump)
323+
all_probs = all_logits
324+
if cfg.probs_cache_file:
325+
os.makedirs(os.path.split(cfg.probs_cache_file)[0], exist_ok=True)
326+
logging.info(f"Writing pickle files of probabilities at '{cfg.probs_cache_file}'...")
327+
with open(cfg.probs_cache_file, 'wb') as f_dump:
328+
pickle.dump(all_probs, f_dump)
304329

305330
wer_dist_greedy = 0
306331
cer_dist_greedy = 0
307332
words_count = 0
308333
chars_count = 0
309-
for batch_idx, hyp in enumerate(all_hyps):
310-
pred_text = apply_text_processing(punctuation_capitalization, cfg, hyp.text)
334+
for batch_idx, probs in enumerate(all_probs):
335+
preds = np.argmax(probs, axis=1)
336+
preds_tensor = torch.tensor(preds, device='cpu').unsqueeze(0)
337+
if isinstance(asr_model, EncDecHybridRNNTCTCModel):
338+
pred_text = asr_model.ctc_decoding.ctc_decoder_predictions_tensor(preds_tensor)[0]
339+
else:
340+
pred_text = asr_model._wer.decoding.ctc_decoder_predictions_tensor(preds_tensor)[0]
341+
342+
if cfg.text_processing.do_lowercase:
343+
pred_text = punctuation_capitalization.do_lowercase([pred_text])[0]
344+
if cfg.text_processing.rm_punctuation:
345+
pred_text = punctuation_capitalization.rm_punctuation([pred_text])[0]
346+
if cfg.text_processing.separate_punctuation:
347+
pred_text = punctuation_capitalization.separate_punctuation([pred_text])[0]
311348

312349
pred_split_w = pred_text.split()
313350
target_split_w = target_transcripts[batch_idx].split()
@@ -344,7 +381,7 @@ def main(cfg: EvalBeamSearchNGramConfig):
344381
best_wer_beam_size, best_cer_beam_size = None, None
345382
best_wer_alpha, best_cer_alpha = None, None
346383
best_wer_beta, best_cer_beta = None, None
347-
best_wer, best_cer = float("inf"), float("inf")
384+
best_wer, best_cer = 1e6, 1e6
348385

349386
logging.info(f"==============================Starting the beam search decoding===============================")
350387
logging.info(f"Grid search size: {len(hp_grid)}")
@@ -363,33 +400,31 @@ def main(cfg: EvalBeamSearchNGramConfig):
363400
preds_output_file = None
364401

365402
candidate_wer, candidate_cer = beam_search_eval(
366-
audio_file_paths,
367403
asr_model,
368404
cfg,
405+
all_probs=all_probs,
369406
target_transcripts=target_transcripts,
370407
preds_output_file=preds_output_file,
371408
lm_path=lm_path,
372409
beam_width=hp["beam_width"],
373410
beam_alpha=hp["beam_alpha"],
374411
beam_beta=hp["beam_beta"],
412+
beam_batch_size=cfg.beam_batch_size,
413+
progress_bar=True,
375414
punctuation_capitalization=punctuation_capitalization,
376415
)
377416

378417
if candidate_cer < best_cer:
379-
best_cer_beam_size, best_cer_alpha, best_cer_beta, best_cer = (
380-
hp["beam_width"],
381-
hp["beam_alpha"],
382-
hp["beam_beta"],
383-
candidate_cer,
384-
)
418+
best_cer_beam_size = hp["beam_width"]
419+
best_cer_alpha = hp["beam_alpha"]
420+
best_cer_beta = hp["beam_beta"]
421+
best_cer = candidate_cer
385422

386423
if candidate_wer < best_wer:
387-
best_wer_beam_size, best_wer_alpha, best_wer_beta, best_wer = (
388-
hp["beam_width"],
389-
hp["beam_alpha"],
390-
hp["beam_beta"],
391-
candidate_wer,
392-
)
424+
best_wer_beam_size = hp["beam_width"]
425+
best_wer_alpha = hp["beam_alpha"]
426+
best_wer_beta = hp["beam_beta"]
427+
best_wer = candidate_wer
393428

394429
logging.info(
395430
f'Best WER Candidate = {best_wer:.2%} :: Beam size = {best_wer_beam_size}, '

0 commit comments

Comments (0)