Skip to content

Commit 8d2f2ad

Browse files
lilithgrigoryan and ko3n1g
authored and committed
fix eval_beamsearch_ngram_ctc script (NVIDIA-NeMo#14238)
* add fix Signed-off-by: lilithgrigoryan <[email protected]> * minor fix Signed-off-by: lilithgrigoryan <[email protected]> * Apply isort and black reformatting Signed-off-by: lilithgrigoryan <[email protected]> * refactor Signed-off-by: lilithgrigoryan <[email protected]> * minor fix Signed-off-by: lilithgrigoryan <[email protected]> * Apply isort and black reformatting Signed-off-by: lilithgrigoryan <[email protected]> * minor fix Signed-off-by: lilithgrigoryan <[email protected]> * fix Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * restore gitignore Signed-off-by: lilithgrigoryan <[email protected]> * restore gitignore Signed-off-by: lilithgrigoryan <[email protected]> --------- Signed-off-by: lilithgrigoryan <[email protected]> Signed-off-by: lilithgrigoryan <[email protected]> Co-authored-by: lilithgrigoryan <[email protected]> Co-authored-by: oliver könig <[email protected]> Signed-off-by: Amir Hussein <[email protected]>
1 parent 863fefe commit 8d2f2ad

File tree

1 file changed

+79
-114
lines changed

1 file changed

+79
-114
lines changed

scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py

Lines changed: 79 additions & 114 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,7 @@
7878
from nemo.core.config import hydra_runner
7979
from nemo.utils import logging
8080

81+
8182
# fmt: off
8283

8384

@@ -93,11 +94,10 @@ class EvalBeamSearchNGramConfig:
9394
input_manifest: str = MISSING # The manifest file of the evaluation set
9495
kenlm_model_file: Optional[str] = None # The path of the KenLM binary model file
9596
preds_output_folder: Optional[str] = None # The optional folder where the predictions are stored
96-
probs_cache_file: Optional[str] = None # The cache file for storing the logprobs of the model
97+
hyps_cache_file: Optional[str] = None # The cache file for storing the hypotheses of the model
9798

9899
# Parameters for inference
99-
acoustic_batch_size: int = 16 # The batch size to calculate log probabilities
100-
beam_batch_size: int = 128 # The batch size to be used for beam search decoding
100+
batch_size: int = 16 # The batch size
101101
device: str = "cuda" # The device to load the model onto to calculate log probabilities
102102
use_amp: bool = False # Whether to use AMP if available to calculate log probabilities
103103

@@ -123,18 +123,31 @@ class EvalBeamSearchNGramConfig:
123123
# fmt: on
124124

125125

126+
def apply_text_processing(
127+
punctuation_capitalization: PunctuationCapitalization, cfg: EvalBeamSearchNGramConfig, text: List[str] | str
128+
) -> List[str] | str:
129+
is_list = isinstance(text, list)
130+
text_arr = text if is_list else [text]
131+
if cfg.text_processing.do_lowercase:
132+
text_arr = punctuation_capitalization.do_lowercase(text_arr)
133+
if cfg.text_processing.rm_punctuation:
134+
text_arr = punctuation_capitalization.rm_punctuation(text_arr)
135+
if cfg.text_processing.separate_punctuation:
136+
text_arr = punctuation_capitalization.separate_punctuation(text_arr)
137+
138+
return text_arr if is_list else text_arr[0]
139+
140+
126141
def beam_search_eval(
142+
audio_filepaths,
127143
model: nemo_asr.models.ASRModel,
128144
cfg: EvalBeamSearchNGramConfig,
129-
all_probs: List[torch.Tensor],
130145
target_transcripts: List[str],
131146
preds_output_file: str = None,
132147
lm_path: str = None,
133148
beam_alpha: float = 1.0,
134149
beam_beta: float = 0.0,
135150
beam_width: int = 128,
136-
beam_batch_size: int = 128,
137-
progress_bar: bool = True,
138151
punctuation_capitalization: PunctuationCapitalization = None,
139152
):
140153
level = logging.getEffectiveLevel()
@@ -159,80 +172,47 @@ def beam_search_eval(
159172
# Update model's decoding strategy
160173
if isinstance(model, EncDecHybridRNNTCTCModel):
161174
model.change_decoding_strategy(model.cfg.decoding, decoder_type='ctc')
162-
decoding = model.ctc_decoding
163175
else:
164176
model.change_decoding_strategy(model.cfg.decoding)
165-
decoding = model.decoding
166177
logging.setLevel(level)
167178

179+
all_hyps = model.transcribe(audio_filepaths, cfg.batch_size)
180+
168181
wer_dist_first = cer_dist_first = 0
169182
wer_dist_best = cer_dist_best = 0
170183
words_count = 0
171184
chars_count = 0
172-
sample_idx = 0
173185
if preds_output_file:
174186
out_file = open(preds_output_file, 'w', encoding='utf_8', newline='\n')
175187

176-
if progress_bar:
177-
it = tqdm(
178-
range(int(np.ceil(len(all_probs) / beam_batch_size))),
179-
desc=f"Beam search decoding with width={beam_width}, alpha={beam_alpha}, beta={beam_beta}",
180-
ncols=120,
181-
)
182-
else:
183-
it = range(int(np.ceil(len(all_probs) / beam_batch_size)))
184-
for batch_idx in it:
185-
# disabling type checking
186-
probs_batch = all_probs[batch_idx * beam_batch_size : (batch_idx + 1) * beam_batch_size]
187-
probs_lens = torch.tensor([prob.shape[0] for prob in probs_batch])
188-
with torch.no_grad():
189-
packed_batch = torch.zeros(len(probs_batch), max(probs_lens), probs_batch[0].shape[-1], device='cpu')
190-
191-
for prob_index in range(len(probs_batch)):
192-
packed_batch[prob_index, : probs_lens[prob_index], :] = torch.tensor(
193-
probs_batch[prob_index], device=packed_batch.device, dtype=packed_batch.dtype
194-
)
188+
for batch_idx, nbest_hyp in enumerate(all_hyps):
189+
target = target_transcripts[batch_idx]
190+
target_split_w = target.split()
191+
target_split_c = list(target)
192+
words_count += len(target_split_w)
193+
chars_count += len(target_split_c)
194+
wer_dist_min = cer_dist_min = float("inf")
195+
for candidate_idx, candidate in enumerate(nbest_hyp):
196+
pred_text = apply_text_processing(punctuation_capitalization, cfg, candidate.text)
195197

196-
beams_batch = decoding.ctc_decoder_predictions_tensor(
197-
packed_batch,
198-
decoder_lengths=probs_lens,
199-
return_hypotheses=True,
200-
)
198+
pred_split_w = pred_text.split()
199+
wer_dist = editdistance.eval(target_split_w, pred_split_w)
200+
pred_split_c = list(pred_text)
201+
cer_dist = editdistance.eval(target_split_c, pred_split_c)
201202

202-
for beams_idx, beams in enumerate(beams_batch):
203-
target = target_transcripts[sample_idx + beams_idx]
204-
target_split_w = target.split()
205-
target_split_c = list(target)
206-
words_count += len(target_split_w)
207-
chars_count += len(target_split_c)
208-
wer_dist_min = cer_dist_min = 10000
209-
for candidate_idx, candidate in enumerate(beams): # type: (int, ctc_beam_decoding.rnnt_utils.Hypothesis)
210-
pred_text = candidate.text
211-
if cfg.text_processing.do_lowercase:
212-
pred_text = punctuation_capitalization.do_lowercase([pred_text])[0]
213-
if cfg.text_processing.rm_punctuation:
214-
pred_text = punctuation_capitalization.rm_punctuation([pred_text])[0]
215-
if cfg.text_processing.separate_punctuation:
216-
pred_text = punctuation_capitalization.separate_punctuation([pred_text])[0]
217-
pred_split_w = pred_text.split()
218-
wer_dist = editdistance.eval(target_split_w, pred_split_w)
219-
pred_split_c = list(pred_text)
220-
cer_dist = editdistance.eval(target_split_c, pred_split_c)
221-
222-
wer_dist_min = min(wer_dist_min, wer_dist)
223-
cer_dist_min = min(cer_dist_min, cer_dist)
224-
225-
if candidate_idx == 0:
226-
# first candidate
227-
wer_dist_first += wer_dist
228-
cer_dist_first += cer_dist
229-
230-
score = candidate.score
231-
if preds_output_file:
232-
out_file.write('{}\t{}\n'.format(pred_text, score))
233-
wer_dist_best += wer_dist_min
234-
cer_dist_best += cer_dist_min
235-
sample_idx += len(probs_batch)
203+
wer_dist_min = min(wer_dist_min, wer_dist)
204+
cer_dist_min = min(cer_dist_min, cer_dist)
205+
206+
if candidate_idx == 0:
207+
# first candidate
208+
wer_dist_first += wer_dist
209+
cer_dist_first += cer_dist
210+
211+
score = candidate.score
212+
if preds_output_file:
213+
out_file.write('{}\t{}\n'.format(pred_text, score))
214+
wer_dist_best += wer_dist_min
215+
cer_dist_best += cer_dist_min
236216

237217
if preds_output_file:
238218
out_file.close()
@@ -255,6 +235,7 @@ def beam_search_eval(
255235
wer_dist_best / words_count, cer_dist_best / chars_count
256236
)
257237
)
238+
258239
logging.info(f"=================================================================================")
259240

260241
return wer_dist_first / words_count, cer_dist_first / chars_count
@@ -294,57 +275,39 @@ def main(cfg: EvalBeamSearchNGramConfig):
294275
audio_file_paths.append(str(audio_file.absolute()))
295276

296277
punctuation_capitalization = PunctuationCapitalization(cfg.text_processing.punctuation_marks)
297-
if cfg.text_processing.do_lowercase:
298-
target_transcripts = punctuation_capitalization.do_lowercase(target_transcripts)
299-
if cfg.text_processing.rm_punctuation:
300-
target_transcripts = punctuation_capitalization.rm_punctuation(target_transcripts)
301-
if cfg.text_processing.separate_punctuation:
302-
target_transcripts = punctuation_capitalization.separate_punctuation(target_transcripts)
278+
target_transcripts = apply_text_processing(punctuation_capitalization, cfg, target_transcripts)
303279

304-
if cfg.probs_cache_file and os.path.exists(cfg.probs_cache_file):
305-
logging.info(f"Found a pickle file of probabilities at '{cfg.probs_cache_file}'.")
306-
logging.info(f"Loading the cached pickle file of probabilities from '{cfg.probs_cache_file}' ...")
307-
with open(cfg.probs_cache_file, 'rb') as probs_file:
308-
all_probs = pickle.load(probs_file)
280+
if cfg.hyps_cache_file and os.path.exists(cfg.hyps_cache_file):
281+
logging.info(f"Found a pickle file of hypotheses at '{cfg.hyps_cache_file}'.")
282+
logging.info(f"Loading the cached pickle file of hypotheses from '{cfg.hyps_cache_file}' ...")
283+
with open(cfg.hyps_cache_file, 'rb') as probs_file:
284+
all_hyps = pickle.load(probs_file)
309285

310-
if len(all_probs) != len(audio_file_paths):
286+
if len(all_hyps) != len(audio_file_paths):
311287
raise ValueError(
312-
f"The number of samples in the probabilities file '{cfg.probs_cache_file}' does not "
313-
f"match the manifest file. You may need to delete the probabilities cached file."
288+
f"The number of samples in the hypotheses file '{cfg.hyps_cache_file}' does not "
289+
f"match the manifest file. You may need to delete the hypotheses cached file."
314290
)
315291
else:
316292

317293
with torch.amp.autocast(asr_model.device.type, enabled=cfg.use_amp):
318294
with torch.no_grad():
319295
if isinstance(asr_model, EncDecHybridRNNTCTCModel):
320296
asr_model.cur_decoder = 'ctc'
321-
all_logits = asr_model.transcribe(audio_file_paths, batch_size=cfg.acoustic_batch_size, logprobs=True)
297+
all_hyps = asr_model.transcribe(audio_file_paths, batch_size=cfg.batch_size)
322298

323-
all_probs = all_logits
324-
if cfg.probs_cache_file:
325-
os.makedirs(os.path.split(cfg.probs_cache_file)[0], exist_ok=True)
326-
logging.info(f"Writing pickle files of probabilities at '{cfg.probs_cache_file}'...")
327-
with open(cfg.probs_cache_file, 'wb') as f_dump:
328-
pickle.dump(all_probs, f_dump)
299+
if cfg.hyps_cache_file:
300+
os.makedirs(os.path.split(cfg.hyps_cache_file)[0], exist_ok=True)
301+
logging.info(f"Writing pickle files of hypotheses at '{cfg.hyps_cache_file}'...")
302+
with open(cfg.hyps_cache_file, 'wb') as f_dump:
303+
pickle.dump(all_hyps, f_dump)
329304

330305
wer_dist_greedy = 0
331306
cer_dist_greedy = 0
332307
words_count = 0
333308
chars_count = 0
334-
for batch_idx, probs in enumerate(all_probs):
335-
preds = np.argmax(probs, axis=1)
336-
preds_tensor = torch.tensor(preds, device='cpu').unsqueeze(0)
337-
if isinstance(asr_model, EncDecHybridRNNTCTCModel):
338-
pred_text = asr_model.ctc_decoding.ctc_decoder_predictions_tensor(preds_tensor)[0]
339-
else:
340-
pred_text = asr_model._wer.decoding.ctc_decoder_predictions_tensor(preds_tensor)[0]
341-
342-
if cfg.text_processing.do_lowercase:
343-
pred_text = punctuation_capitalization.do_lowercase([pred_text])[0]
344-
if cfg.text_processing.rm_punctuation:
345-
pred_text = punctuation_capitalization.rm_punctuation([pred_text])[0]
346-
if cfg.text_processing.separate_punctuation:
347-
pred_text = punctuation_capitalization.separate_punctuation([pred_text])[0]
309+
for batch_idx, hyp in enumerate(all_hyps):
310+
pred_text = apply_text_processing(punctuation_capitalization, cfg, hyp.text)
348311

349312
pred_split_w = pred_text.split()
350313
target_split_w = target_transcripts[batch_idx].split()
@@ -381,7 +344,7 @@ def main(cfg: EvalBeamSearchNGramConfig):
381344
best_wer_beam_size, best_cer_beam_size = None, None
382345
best_wer_alpha, best_cer_alpha = None, None
383346
best_wer_beta, best_cer_beta = None, None
384-
best_wer, best_cer = 1e6, 1e6
347+
best_wer, best_cer = float("inf"), float("inf")
385348

386349
logging.info(f"==============================Starting the beam search decoding===============================")
387350
logging.info(f"Grid search size: {len(hp_grid)}")
@@ -400,31 +363,33 @@ def main(cfg: EvalBeamSearchNGramConfig):
400363
preds_output_file = None
401364

402365
candidate_wer, candidate_cer = beam_search_eval(
366+
audio_file_paths,
403367
asr_model,
404368
cfg,
405-
all_probs=all_probs,
406369
target_transcripts=target_transcripts,
407370
preds_output_file=preds_output_file,
408371
lm_path=lm_path,
409372
beam_width=hp["beam_width"],
410373
beam_alpha=hp["beam_alpha"],
411374
beam_beta=hp["beam_beta"],
412-
beam_batch_size=cfg.beam_batch_size,
413-
progress_bar=True,
414375
punctuation_capitalization=punctuation_capitalization,
415376
)
416377

417378
if candidate_cer < best_cer:
418-
best_cer_beam_size = hp["beam_width"]
419-
best_cer_alpha = hp["beam_alpha"]
420-
best_cer_beta = hp["beam_beta"]
421-
best_cer = candidate_cer
379+
best_cer_beam_size, best_cer_alpha, best_cer_beta, best_cer = (
380+
hp["beam_width"],
381+
hp["beam_alpha"],
382+
hp["beam_beta"],
383+
candidate_cer,
384+
)
422385

423386
if candidate_wer < best_wer:
424-
best_wer_beam_size = hp["beam_width"]
425-
best_wer_alpha = hp["beam_alpha"]
426-
best_wer_beta = hp["beam_beta"]
427-
best_wer = candidate_wer
387+
best_wer_beam_size, best_wer_alpha, best_wer_beta, best_wer = (
388+
hp["beam_width"],
389+
hp["beam_alpha"],
390+
hp["beam_beta"],
391+
candidate_wer,
392+
)
428393

429394
logging.info(
430395
f'Best WER Candidate = {best_wer:.2%} :: Beam size = {best_wer_beam_size}, '

0 commit comments

Comments
 (0)