Commit 6a7e577

lilithgrigoryan, ko3n1g, and artbataev
authored and committed
add CTC batched beam search (NVIDIA-NeMo#13337)
* add ctc beam decoding Signed-off-by: lilithgrigoryan <[email protected]> * add utils Signed-off-by: lilithgrigoryan <[email protected]> * first working Signed-off-by: lilithgrigoryan <[email protected]> * working cuda graphs Signed-off-by: lilithgrigoryan <[email protected]> * fix bugs with cudagraohs Signed-off-by: lilithgrigoryan <[email protected]> * working Signed-off-by: lilithgrigoryan <[email protected]> * small fix Signed-off-by: lilithgrigoryan <[email protected]> * minor fix Signed-off-by: lilithgrigoryan <[email protected]> * add logging Signed-off-by: lilithgrigoryan <[email protected]> * add print Signed-off-by: lilithgrigoryan <[email protected]> * to log sum exp Signed-off-by: lilithgrigoryan <[email protected]> * back to max score Signed-off-by: lilithgrigoryan <[email protected]> * fix bug in cudagraphs, save before refactor Signed-off-by: lilithgrigoryan <[email protected]> * rm log10 Signed-off-by: lilithgrigoryan <[email protected]> * rm prints Signed-off-by: lilithgrigoryan <[email protected]> * add reallocation Signed-off-by: lilithgrigoryan <[email protected]> * rm logprobs from state Signed-off-by: lilithgrigoryan <[email protected]> * rm nexts from state Signed-off-by: lilithgrigoryan <[email protected]> * rm prev lm states Signed-off-by: lilithgrigoryan <[email protected]> * small clean up Signed-off-by: lilithgrigoryan <[email protected]> * clean up cuda graphs Signed-off-by: lilithgrigoryan <[email protected]> * cudagraph working Signed-off-by: lilithgrigoryan <[email protected]> * clean up torch working Signed-off-by: lilithgrigoryan <[email protected]> * Apply isort and black reformatting Signed-off-by: lilithgrigoryan <[email protected]> * rm files Signed-off-by: lilithgrigoryan <[email protected]> * save Signed-off-by: lilithgrigoryan <[email protected]> * add flatten Signed-off-by: lilithgrigoryan <[email protected]> * Apply isort and black reformatting Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: 
lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * add timestamps Signed-off-by: lilithgrigoryan <[email protected]> * rm file Signed-off-by: lilithgrigoryan <[email protected]> * Apply isort and black reformatting Signed-off-by: lilithgrigoryan <[email protected]> * rename file Signed-off-by: lilithgrigoryan <[email protected]> * add batched beam tests Signed-off-by: lilithgrigoryan <[email protected]> * Apply isort and black reformatting Signed-off-by: lilithgrigoryan <[email protected]> * add tests Signed-off-by: lilithgrigoryan <[email protected]> * Apply isort and black reformatting Signed-off-by: lilithgrigoryan <[email protected]> * changed return type Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * Apply isort and black reformatting Signed-off-by: lilithgrigoryan <[email protected]> * minor changes Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * renamed variables Signed-off-by: lilithgrigoryan <[email protected]> * changed is_tdt to model_type Signed-off-by: lilithgrigoryan <[email protected]> * unified batched beam hyps Signed-off-by: lilithgrigoryan <[email protected]> * Apply isort and black reformatting Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * Update cuda_program_name Co-authored-by: Vladimir Bataev <[email protected]> Signed-off-by: lilithgrigoryan <[email protected]> * 
clean up and and commments Signed-off-by: lilithgrigoryan <[email protected]> * clean up and small fixes Signed-off-by: lilithgrigoryan <[email protected]> * Apply isort and black reformatting Signed-off-by: lilithgrigoryan <[email protected]> * fix Signed-off-by: lilithgrigoryan <[email protected]> * fix tests Signed-off-by: lilithgrigoryan <[email protected]> * added check on model type Signed-off-by: lilithgrigoryan <[email protected]> * minor change Signed-off-by: lilithgrigoryan <[email protected]> * rm repetitions LM scoring Signed-off-by: lilithgrigoryan <[email protected]> * add enum model type Signed-off-by: lilithgrigoryan <[email protected]> * add enum model type Signed-off-by: lilithgrigoryan <[email protected]> * Apply isort and black reformatting Signed-off-by: lilithgrigoryan <[email protected]> * fix lm repetitions for cudahraphs Signed-off-by: lilithgrigoryan <[email protected]> * Apply isort and black reformatting Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> * clean up Signed-off-by: lilithgrigoryan <[email protected]> --------- Signed-off-by: lilithgrigoryan <[email protected]> Signed-off-by: lilithgrigoryan <[email protected]> Signed-off-by: lilithgrigoryan <[email protected]> Co-authored-by: lilithgrigoryan <[email protected]> Co-authored-by: oliver könig <[email protected]> Co-authored-by: Vladimir Bataev <[email protected]> Signed-off-by: Amir Hussein <[email protected]>
1 parent 52315a0 commit 6a7e577

13 files changed: +1615 -129 lines changed

nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py

Lines changed: 696 additions & 0 deletions
Large diffs are not rendered by default.
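Since the 696-line diff for this file is not rendered, the following is a minimal, single-utterance sketch of textbook CTC prefix beam search in pure Python, included only to illustrate the general technique. It is not the batched, CUDA-graph-capable `BatchedBeamCTCComputer` added by this commit, and it omits LM fusion and threshold pruning. The names `ctc_prefix_beam_search` and `_logsumexp` are hypothetical.

```python
import math
from collections import defaultdict

NEG_INF = float("-inf")


def _logsumexp(*xs):
    """Numerically stable log(sum(exp(x))) over a few scalars."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def ctc_prefix_beam_search(log_probs, blank_index, beam_size):
    """Textbook CTC prefix beam search (illustrative sketch, not NeMo code).

    log_probs: list of T frames, each a list of V log-probabilities.
    Returns a list of (token_ids, log_score), best hypothesis first.
    """
    # prefix -> (log p of paths ending in blank, log p of paths ending in non-blank)
    beams = {(): (0.0, NEG_INF)}
    for frame in log_probs:
        next_beams = defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (p_b, p_nb) in beams.items():
            for token, lp in enumerate(frame):
                if token == blank_index:
                    # Blank keeps the prefix unchanged.
                    cur_b, cur_nb = next_beams[prefix]
                    next_beams[prefix] = (_logsumexp(cur_b, p_b + lp, p_nb + lp), cur_nb)
                    continue
                new_prefix = prefix + (token,)
                cur_b, cur_nb = next_beams[new_prefix]
                if prefix and prefix[-1] == token:
                    # A repeated token extends the prefix only after a blank ...
                    next_beams[new_prefix] = (cur_b, _logsumexp(cur_nb, p_b + lp))
                    # ... otherwise it collapses into the same prefix.
                    same_b, same_nb = next_beams[prefix]
                    next_beams[prefix] = (same_b, _logsumexp(same_nb, p_nb + lp))
                else:
                    next_beams[new_prefix] = (cur_b, _logsumexp(cur_nb, p_b + lp, p_nb + lp))
        # Keep only the `beam_size` best prefixes by total log-probability.
        ranked = sorted(next_beams.items(), key=lambda kv: _logsumexp(*kv[1]), reverse=True)
        beams = dict(ranked[:beam_size])
    return [(list(prefix), _logsumexp(*score)) for prefix, score in beams.items()]


# Toy input: vocabulary {0: "a", 1: "b"}, blank index 2, two frames.
frames = [
    [math.log(0.6), math.log(0.1), math.log(0.3)],
    [math.log(0.1), math.log(0.6), math.log(0.3)],
]
hyps = ctc_prefix_beam_search(frames, blank_index=2, beam_size=3)
best_tokens, best_score = hyps[0]
```

On this toy input the best hypothesis is the token sequence `[0, 1]`, which is reached by a single alignment path with probability 0.6 * 0.6 = 0.36.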

nemo/collections/asr/parts/submodules/ctc_beam_decoding.py

Lines changed: 124 additions & 17 deletions
@@ -22,6 +22,7 @@
2222
import torch
2323

2424
from nemo.collections.asr.parts.k2.classes import GraphIntersectDenseConfig
25+
from nemo.collections.asr.parts.submodules.ctc_batched_beam_decoding import BatchedBeamCTCComputer
2526
from nemo.collections.asr.parts.submodules.ngram_lm import DEFAULT_TOKEN_OFFSET
2627
from nemo.collections.asr.parts.submodules.wfst_decoder import RivaDecoderConfig, WfstNbestHypothesis
2728
from nemo.collections.asr.parts.utils import rnnt_utils
@@ -204,7 +205,7 @@ def __call__(self, *args, **kwargs):
204205

205206

206207
class BeamCTCInfer(AbstractBeamCTCInfer):
207-
"""A greedy CTC decoder.
208+
"""A beam CTC decoder.
208209
209210
Provides a common abstraction for sample level and batch level greedy decoding.
210211
@@ -227,9 +228,9 @@ def __init__(
227228
return_best_hypothesis: bool = True,
228229
preserve_alignments: bool = False,
229230
compute_timestamps: bool = False,
230-
beam_alpha: float = 1.0,
231+
ngram_lm_alpha: float = 0.3,
231232
beam_beta: float = 0.0,
232-
kenlm_path: str = None,
233+
ngram_lm_model: str = None,
233234
flashlight_cfg: Optional['FlashlightConfig'] = None,
234235
pyctcdecode_cfg: Optional['PyCTCDecodeConfig'] = None,
235236
):
@@ -260,11 +261,11 @@ def __init__(
260261
# Log the beam search algorithm
261262
logging.info(f"Beam search algorithm: {search_type}")
262263

263-
self.beam_alpha = beam_alpha
264+
self.ngram_lm_alpha = ngram_lm_alpha
264265
self.beam_beta = beam_beta
265266

266267
# Default beam search args
267-
self.kenlm_path = kenlm_path
268+
self.ngram_lm_model = ngram_lm_model
268269

269270
# PyCTCDecode params
270271
if pyctcdecode_cfg is None:
@@ -349,9 +350,9 @@ def default_beam_search(
349350

350351
if self.default_beam_scorer is None:
351352
# Check for filepath
352-
if self.kenlm_path is None or not os.path.exists(self.kenlm_path):
353+
if self.ngram_lm_model is None or not os.path.exists(self.ngram_lm_model):
353354
raise FileNotFoundError(
354-
f"KenLM binary file not found at : {self.kenlm_path}. "
355+
f"KenLM binary file not found at : {self.ngram_lm_model}. "
355356
f"Please set a valid path in the decoding config."
356357
)
357358

@@ -367,9 +368,9 @@ def default_beam_search(
367368

368369
self.default_beam_scorer = BeamSearchDecoderWithLM(
369370
vocab=vocab,
370-
lm_path=self.kenlm_path,
371+
lm_path=self.ngram_lm_model,
371372
beam_width=self.beam_size,
372-
alpha=self.beam_alpha,
373+
alpha=self.ngram_lm_alpha,
373374
beta=self.beam_beta,
374375
num_cpus=max(1, os.cpu_count()),
375376
input_tensor=False,
@@ -451,7 +452,7 @@ def _pyctcdecode_beam_search(
451452

452453
if self.pyctcdecode_beam_scorer is None:
453454
self.pyctcdecode_beam_scorer = pyctcdecode.build_ctcdecoder(
454-
labels=self.vocab, kenlm_model_path=self.kenlm_path, alpha=self.beam_alpha, beta=self.beam_beta
455+
labels=self.vocab, kenlm_model_path=self.ngram_lm_model, alpha=self.ngram_lm_alpha, beta=self.beam_beta
455456
) # type: pyctcdecode.BeamSearchDecoderCTC
456457

457458
x = x.to('cpu').numpy()
@@ -533,9 +534,9 @@ def flashlight_beam_search(
533534

534535
if self.flashlight_beam_scorer is None:
535536
# Check for filepath
536-
if self.kenlm_path is None or not os.path.exists(self.kenlm_path):
537+
if self.ngram_lm_model is None or not os.path.exists(self.ngram_lm_model):
537538
raise FileNotFoundError(
538-
f"KenLM binary file not found at : {self.kenlm_path}. "
539+
f"KenLM binary file not found at : {self.ngram_lm_model}. "
539540
"Please set a valid path in the decoding config."
540541
)
541542

@@ -550,15 +551,15 @@ def flashlight_beam_search(
550551
from nemo.collections.asr.modules.flashlight_decoder import FlashLightKenLMBeamSearchDecoder
551552

552553
self.flashlight_beam_scorer = FlashLightKenLMBeamSearchDecoder(
553-
lm_path=self.kenlm_path,
554+
lm_path=self.ngram_lm_model,
554555
vocabulary=self.vocab,
555556
tokenizer=self.tokenizer,
556557
lexicon_path=self.flashlight_cfg.lexicon_path,
557558
boost_path=self.flashlight_cfg.boost_path,
558559
beam_size=self.beam_size,
559560
beam_size_token=self.flashlight_cfg.beam_size_token,
560561
beam_threshold=self.flashlight_cfg.beam_threshold,
561-
lm_weight=self.beam_alpha,
562+
lm_weight=self.ngram_lm_alpha,
562563
word_score=self.beam_beta,
563564
unk_weight=self.flashlight_cfg.unk_weight,
564565
sil_weight=self.flashlight_cfg.sil_weight,
@@ -877,6 +878,108 @@ def _k2_decoding(self, x: torch.Tensor, out_len: torch.Tensor) -> List['WfstNbes
877878
return self.k2_decoder.decode(x.to(device=self.device), out_len.to(device=self.device))
878879

879880

881+
class BeamBatchedCTCInfer(AbstractBeamCTCInfer):
882+
"""
883+
A batched beam CTC decoder.
884+
885+
Args:
886+
blank_index: int index of the blank token. Can be 0 or len(vocabulary).
887+
beam_size: int size of the beam.
888+
return_best_hypothesis: When set to True (default), returns a single Hypothesis.
889+
When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis.
890+
preserve_alignments: Bool flag which preserves the history of logprobs generated during
891+
decoding (sample / batched). When set to true, the Hypothesis will contain
892+
the non-null value for `logprobs` in it. Here, `logprobs` is a torch.Tensor.
893+
compute_timestamps: A bool flag, which determines whether to compute the character/subword, or
894+
word based timestamp mapping the output log-probabilities to discrete intervals of timestamps.
895+
The timestamps will be available in the returned Hypothesis.timestep as a dictionary.
896+
ngram_lm_alpha: float, the language model weight.
897+
beam_beta: float, the word insertion weight.
898+
beam_threshold: float, the beam pruning threshold.
899+
ngram_lm_model: str, the path to the ngram model.
900+
allow_cuda_graphs: bool, whether to allow cuda graphs for the beam search algorithm.
901+
"""
902+
903+
def __init__(
904+
self,
905+
blank_index: int,
906+
beam_size: int,
907+
return_best_hypothesis: bool = True,
908+
preserve_alignments: bool = False,
909+
compute_timestamps: bool = False,
910+
ngram_lm_alpha: float = 1.0,
911+
beam_beta: float = 0.0,
912+
beam_threshold: float = 20.0,
913+
ngram_lm_model: str = None,
914+
allow_cuda_graphs: bool = True,
915+
):
916+
super().__init__(blank_id=blank_index, beam_size=beam_size)
917+
918+
self.return_best_hypothesis = return_best_hypothesis
919+
self.preserve_alignments = preserve_alignments
920+
self.compute_timestamps = compute_timestamps
921+
self.allow_cuda_graphs = allow_cuda_graphs
922+
923+
if self.compute_timestamps:
924+
raise ValueError("`Compute timestamps` is not supported for batched beam search.")
925+
if self.preserve_alignments:
926+
raise ValueError("`Preserve alignments` is not supported for batched beam search.")
927+
928+
self.ngram_lm_alpha = ngram_lm_alpha
929+
self.beam_beta = beam_beta
930+
self.beam_threshold = beam_threshold
931+
932+
# Default beam search args
933+
self.ngram_lm_model = ngram_lm_model
934+
935+
self.search_algorithm = BatchedBeamCTCComputer(
936+
blank_index=blank_index,
937+
beam_size=beam_size,
938+
return_best_hypothesis=return_best_hypothesis,
939+
preserve_alignments=preserve_alignments,
940+
compute_timestamps=compute_timestamps,
941+
ngram_lm_alpha=ngram_lm_alpha,
942+
beam_beta=beam_beta,
943+
beam_threshold=beam_threshold,
944+
ngram_lm_model=ngram_lm_model,
945+
allow_cuda_graphs=allow_cuda_graphs,
946+
)
947+
948+
@typecheck()
949+
def forward(
950+
self,
951+
decoder_output: torch.Tensor,
952+
decoder_lengths: torch.Tensor,
953+
) -> Tuple[List[Union[rnnt_utils.Hypothesis, rnnt_utils.NBestHypotheses]]]:
954+
"""Returns a list of hypotheses given an input batch of the encoder hidden embedding.
955+
Output tokens are generated auto-regressively.
956+
957+
Args:
958+
decoder_output: A tensor of size (batch, timesteps, features).
959+
decoder_lengths: list of int representing the length of each
960+
output sequence.
961+
962+
Returns:
963+
packed list containing batch number of sentences (Hypotheses).
964+
"""
965+
with torch.no_grad(), torch.inference_mode():
966+
if decoder_output.ndim != 3:
967+
raise ValueError(
968+
f"`decoder_output` must be a tensor of shape [B, T, V] (log probs, float). "
969+
f"Provided shape = {decoder_output.shape}"
970+
)
971+
972+
batched_beam_hyps = self.search_algorithm(decoder_output, decoder_lengths)
973+
974+
batch_size = decoder_lengths.shape[0]
975+
if self.return_best_hypothesis:
976+
hyps = batched_beam_hyps.to_hyps_list(score_norm=False)[:batch_size]
977+
else:
978+
hyps = batched_beam_hyps.to_nbest_hyps_list(score_norm=False)[:batch_size]
979+
980+
return (hyps,)
981+
982+
880983
@dataclass
881984
class PyCTCDecodeConfig:
882985
# These arguments cannot be imported from pyctcdecode (optional dependency)
@@ -906,10 +1009,14 @@ class BeamCTCInferConfig:
9061009
preserve_alignments: bool = False
9071010
compute_timestamps: bool = False
9081011
return_best_hypothesis: bool = True
1012+
allow_cuda_graphs: bool = True
9091013

910-
beam_alpha: float = 1.0
911-
beam_beta: float = 0.0
912-
kenlm_path: Optional[str] = None
1014+
beam_alpha: Optional[float] = None # Deprecated
1015+
beam_beta: float = 1.0
1016+
beam_threshold: float = 20.0
1017+
kenlm_path: Optional[str] = None # Deprecated, default should be None
1018+
ngram_lm_alpha: Optional[float] = 1.0
1019+
ngram_lm_model: Optional[str] = None
9131020

9141021
flashlight_cfg: Optional[FlashlightConfig] = field(default_factory=lambda: FlashlightConfig())
9151022
pyctcdecode_cfg: Optional[PyCTCDecodeConfig] = field(default_factory=lambda: PyCTCDecodeConfig())
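The docstring above describes `beam_threshold` only as "the beam pruning threshold"; how it is applied lives in the unrendered `ctc_batched_beam_decoding.py`. A common convention, assumed here and not confirmed by this diff, is to drop any hypothesis whose score falls more than `beam_threshold` below the best score in the beam. The helper name `prune_by_threshold` is hypothetical.

```python
def prune_by_threshold(scores, beam_threshold):
    """Return a keep-mask for hypothesis scores (log domain).

    Assumption: a hypothesis survives if it is within `beam_threshold`
    of the best score. This illustrates a common pruning rule, not
    necessarily the rule used by the commit.
    """
    best = max(scores)
    return [score >= best - beam_threshold for score in scores]


# With the default threshold of 20.0 from the config above:
keep = prune_by_threshold([-1.0, -5.0, -30.0], beam_threshold=20.0)
```

Here the third hypothesis is 29 below the best score, so it would be pruned under this assumed rule while the first two are kept.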

nemo/collections/asr/parts/submodules/ctc_decoding.py

Lines changed: 49 additions & 19 deletions
@@ -178,15 +178,15 @@ class AbstractCTCDecoding(ConfidenceMixin):
178178
optional bool, whether to return just the best hypothesis or all of the
179179
hypotheses after beam search has concluded. This flag is set by default.
180180
181-
beam_alpha:
181+
ngram_lm_alpha:
182182
float, the strength of the Language model on the final score of a token.
183-
final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length.
183+
final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.
184184
185185
beam_beta:
186186
float, the strength of the sequence length penalty on the final score of a token.
187-
final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length.
187+
final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.
188188
189-
kenlm_path:
189+
ngram_lm_model:
190190
str, path to a KenLM ARPA or .binary file (depending on the strategy chosen).
191191
If the path is invalid (file is not found at path), will raise a deferred error at the moment
192192
of calculation of beam search, so that users may update / change the decoding strategy
@@ -226,7 +226,7 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[
226226
self.segment_seperators = self.cfg.get('segment_seperators', ['.', '?', '!'])
227227
self.segment_gap_threshold = self.cfg.get('segment_gap_threshold', None)
228228

229-
possible_strategies = ['greedy', 'greedy_batch', 'beam', 'pyctcdecode', 'flashlight', 'wfst']
229+
possible_strategies = ['greedy', 'greedy_batch', 'beam', 'pyctcdecode', 'flashlight', 'wfst', 'beam_batch']
230230
if self.cfg.strategy not in possible_strategies:
231231
raise ValueError(f"Decoding strategy must be one of {possible_strategies}. Given {self.cfg.strategy}")
232232

@@ -267,6 +267,20 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[
267267
if self.compute_timestamps is not None:
268268
self.compute_timestamps |= self.preserve_frame_confidence
269269

270+
if self.cfg.strategy in ['flashlight', 'wfst', 'beam_batch', 'pyctcdecode', 'beam']:
271+
if self.cfg.beam.beam_alpha is not None:
272+
logging.warning(
273+
"`beam_alpha` is deprecated and will be removed in a future release. "
274+
"Please use `ngram_lm_alpha` instead."
275+
)
276+
self.cfg.beam.ngram_lm_alpha = self.cfg.beam.beam_alpha
277+
if self.cfg.beam.kenlm_path is not None:
278+
logging.warning(
279+
"`kenlm_path` is deprecated and will be removed in a future release. "
280+
"Please use `ngram_lm_model` instead."
281+
)
282+
self.cfg.beam.ngram_lm_model = self.cfg.beam.kenlm_path
283+
270284
if self.cfg.strategy == 'greedy':
271285
self.decoding = ctc_greedy_decoding.GreedyCTCInfer(
272286
blank_id=self.blank_id,
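The deprecation handling in the hunk above can be sketched on a plain dictionary, without NeMo or OmegaConf. The function name `resolve_deprecated_beam_keys` is hypothetical; the key names and the precedence (a non-`None` deprecated value overwrites the new key) follow the diff.

```python
import warnings


def resolve_deprecated_beam_keys(beam_cfg):
    """Copy deprecated beam-config keys onto their replacements.

    Mirrors the diff: if `beam_alpha` or `kenlm_path` is set (not None),
    warn and overwrite `ngram_lm_alpha` / `ngram_lm_model` with it.
    """
    resolved = dict(beam_cfg)  # do not mutate the caller's config
    renames = {"beam_alpha": "ngram_lm_alpha", "kenlm_path": "ngram_lm_model"}
    for old_key, new_key in renames.items():
        if resolved.get(old_key) is not None:
            warnings.warn(
                f"`{old_key}` is deprecated and will be removed in a future release. "
                f"Please use `{new_key}` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            resolved[new_key] = resolved[old_key]
    return resolved


cfg = {"beam_alpha": 0.5, "ngram_lm_alpha": 1.0, "kenlm_path": None, "ngram_lm_model": "lm.bin"}
resolved = resolve_deprecated_beam_keys(cfg)
```

Note the consequence of this precedence: a user who sets both `beam_alpha` and `ngram_lm_alpha` gets the deprecated value, which is why the new config defaults `beam_alpha` and `kenlm_path` to `None`.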
@@ -294,9 +308,9 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[
294308
return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True),
295309
preserve_alignments=self.preserve_alignments,
296310
compute_timestamps=self.compute_timestamps,
297-
beam_alpha=self.cfg.beam.get('beam_alpha', 1.0),
311+
ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 1.0),
298312
beam_beta=self.cfg.beam.get('beam_beta', 0.0),
299-
kenlm_path=self.cfg.beam.get('kenlm_path', None),
313+
ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None),
300314
)
301315

302316
self.decoding.override_fold_consecutive_value = False
@@ -310,9 +324,9 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[
310324
return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True),
311325
preserve_alignments=self.preserve_alignments,
312326
compute_timestamps=self.compute_timestamps,
313-
beam_alpha=self.cfg.beam.get('beam_alpha', 1.0),
327+
ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 1.0),
314328
beam_beta=self.cfg.beam.get('beam_beta', 0.0),
315-
kenlm_path=self.cfg.beam.get('kenlm_path', None),
329+
ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None),
316330
pyctcdecode_cfg=self.cfg.beam.get('pyctcdecode_cfg', None),
317331
)
318332

@@ -327,9 +341,9 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[
327341
return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True),
328342
preserve_alignments=self.preserve_alignments,
329343
compute_timestamps=self.compute_timestamps,
330-
beam_alpha=self.cfg.beam.get('beam_alpha', 1.0),
344+
ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 1.0),
331345
beam_beta=self.cfg.beam.get('beam_beta', 0.0),
332-
kenlm_path=self.cfg.beam.get('kenlm_path', None),
346+
ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None),
333347
flashlight_cfg=self.cfg.beam.get('flashlight_cfg', None),
334348
)
335349

@@ -357,6 +371,22 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[
357371

358372
self.decoding.override_fold_consecutive_value = False
359373

374+
elif self.cfg.strategy == 'beam_batch':
375+
self.decoding = ctc_beam_decoding.BeamBatchedCTCInfer(
376+
blank_index=blank_id,
377+
beam_size=self.cfg.beam.get('beam_size', 1),
378+
return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True),
379+
preserve_alignments=self.preserve_alignments,
380+
compute_timestamps=self.compute_timestamps,
381+
ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 1.0),
382+
beam_beta=self.cfg.beam.get('beam_beta', 0.0),
383+
beam_threshold=self.cfg.beam.get('beam_threshold', 20.0),
384+
ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None),
385+
allow_cuda_graphs=self.cfg.beam.get('allow_cuda_graphs', True),
386+
)
387+
388+
self.decoding.override_fold_consecutive_value = False
389+
360390
else:
361391
raise ValueError(
362392
f"Incorrect decoding strategy supplied. Must be one of {possible_strategies}\n"
@@ -1051,15 +1081,15 @@ class CTCDecoding(AbstractCTCDecoding):
10511081
optional bool, whether to return just the best hypothesis or all of the
10521082
hypotheses after beam search has concluded. This flag is set by default.
10531083
1054-
beam_alpha:
1084+
ngram_lm_alpha:
10551085
float, the strength of the Language model on the final score of a token.
1056-
final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length.
1086+
final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.
10571087
10581088
beam_beta:
10591089
float, the strength of the sequence length penalty on the final score of a token.
1060-
final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length.
1090+
final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.
10611091
1062-
kenlm_path:
1092+
ngram_lm_model:
10631093
str, path to a KenLM ARPA or .binary file (depending on the strategy chosen).
10641094
If the path is invalid (file is not found at path), will raise a deferred error at the moment
10651095
of calculation of beam search, so that users may update / change the decoding strategy
@@ -1340,15 +1370,15 @@ class CTCBPEDecoding(AbstractCTCDecoding):
13401370
optional bool, whether to return just the best hypothesis or all of the
13411371
hypotheses after beam search has concluded. This flag is set by default.
13421372
1343-
beam_alpha:
1373+
ngram_lm_alpha:
13441374
float, the strength of the Language model on the final score of a token.
1345-
final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length.
1375+
final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.
13461376
13471377
beam_beta:
13481378
float, the strength of the sequence length penalty on the final score of a token.
1349-
final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length.
1379+
final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.
13501380
1351-
kenlm_path:
1381+
ngram_lm_model:
13521382
str, path to a KenLM ARPA or .binary file (depending on the strategy chosen).
13531383
If the path is invalid (file is not found at path), will raise a deferred error at the moment
13541384
of calculation of beam search, so that users may update / change the decoding strategy
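The scoring formula quoted in these docstrings, `final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length`, can be checked with a small worked example. The function name `fused_score` and the input values are illustrative only.

```python
def fused_score(acoustic_score, lm_score, seq_length, ngram_lm_alpha=1.0, beam_beta=0.0):
    """Combine acoustic and LM log-scores as stated in the docstring formula."""
    return acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length


# -4.0 + 0.5 * (-6.0) + 1.0 * 3 = -4.0
score = fused_score(acoustic_score=-4.0, lm_score=-6.0, seq_length=3, ngram_lm_alpha=0.5, beam_beta=1.0)
```

With `beam_beta = 0.0` the length term vanishes, and with `ngram_lm_alpha = 0.0` the score reduces to the acoustic score alone.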

nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@
4040
from nemo.collections.asr.parts.submodules.rnnt_maes_batched_computer import ModifiedAESBatchedRNNTComputer
4141
from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer
4242
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin
43-
from nemo.collections.asr.parts.utils.rnnt_batched_beam_utils import BlankLMScoreMode, PruningMode
43+
from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import BlankLMScoreMode, PruningMode
4444
from nemo.collections.asr.parts.utils.rnnt_utils import (
4545
HATJointOutput,
4646
Hypothesis,
