2222import torch
2323
2424from nemo .collections .asr .parts .k2 .classes import GraphIntersectDenseConfig
25+ from nemo .collections .asr .parts .submodules .ctc_batched_beam_decoding import BatchedBeamCTCComputer
2526from nemo .collections .asr .parts .submodules .ngram_lm import DEFAULT_TOKEN_OFFSET
2627from nemo .collections .asr .parts .submodules .wfst_decoder import RivaDecoderConfig , WfstNbestHypothesis
2728from nemo .collections .asr .parts .utils import rnnt_utils
@@ -204,7 +205,7 @@ def __call__(self, *args, **kwargs):
204205
205206
206207class BeamCTCInfer (AbstractBeamCTCInfer ):
207- """A greedy CTC decoder.
208+ """A beam CTC decoder.
208209
209210 Provides a common abstraction for sample level and batch level greedy decoding.
210211
@@ -227,9 +228,9 @@ def __init__(
227228 return_best_hypothesis : bool = True ,
228229 preserve_alignments : bool = False ,
229230 compute_timestamps : bool = False ,
230- beam_alpha : float = 1.0 ,
231+ ngram_lm_alpha : float = 0.3 ,
231232 beam_beta : float = 0.0 ,
232- kenlm_path : str = None ,
233+ ngram_lm_model : str = None ,
233234 flashlight_cfg : Optional ['FlashlightConfig' ] = None ,
234235 pyctcdecode_cfg : Optional ['PyCTCDecodeConfig' ] = None ,
235236 ):
@@ -260,11 +261,11 @@ def __init__(
260261 # Log the beam search algorithm
261262 logging .info (f"Beam search algorithm: { search_type } " )
262263
263- self .beam_alpha = beam_alpha
264+ self .ngram_lm_alpha = ngram_lm_alpha
264265 self .beam_beta = beam_beta
265266
266267 # Default beam search args
267- self .kenlm_path = kenlm_path
268+ self .ngram_lm_model = ngram_lm_model
268269
269270 # PyCTCDecode params
270271 if pyctcdecode_cfg is None :
@@ -349,9 +350,9 @@ def default_beam_search(
349350
350351 if self .default_beam_scorer is None :
351352 # Check for filepath
352- if self .kenlm_path is None or not os .path .exists (self .kenlm_path ):
353+ if self .ngram_lm_model is None or not os .path .exists (self .ngram_lm_model ):
353354 raise FileNotFoundError (
354- f"KenLM binary file not found at : { self .kenlm_path } . "
355+ f"KenLM binary file not found at : { self .ngram_lm_model } . "
355356 f"Please set a valid path in the decoding config."
356357 )
357358
@@ -367,9 +368,9 @@ def default_beam_search(
367368
368369 self .default_beam_scorer = BeamSearchDecoderWithLM (
369370 vocab = vocab ,
370- lm_path = self .kenlm_path ,
371+ lm_path = self .ngram_lm_model ,
371372 beam_width = self .beam_size ,
372- alpha = self .beam_alpha ,
373+ alpha = self .ngram_lm_alpha ,
373374 beta = self .beam_beta ,
374375 num_cpus = max (1 , os .cpu_count ()),
375376 input_tensor = False ,
@@ -451,7 +452,7 @@ def _pyctcdecode_beam_search(
451452
452453 if self .pyctcdecode_beam_scorer is None :
453454 self .pyctcdecode_beam_scorer = pyctcdecode .build_ctcdecoder (
454- labels = self .vocab , kenlm_model_path = self .kenlm_path , alpha = self .beam_alpha , beta = self .beam_beta
455+ labels = self .vocab , kenlm_model_path = self .ngram_lm_model , alpha = self .ngram_lm_alpha , beta = self .beam_beta
455456 ) # type: pyctcdecode.BeamSearchDecoderCTC
456457
457458 x = x .to ('cpu' ).numpy ()
@@ -533,9 +534,9 @@ def flashlight_beam_search(
533534
534535 if self .flashlight_beam_scorer is None :
535536 # Check for filepath
536- if self .kenlm_path is None or not os .path .exists (self .kenlm_path ):
537+ if self .ngram_lm_model is None or not os .path .exists (self .ngram_lm_model ):
537538 raise FileNotFoundError (
538- f"KenLM binary file not found at : { self .kenlm_path } . "
539+ f"KenLM binary file not found at : { self .ngram_lm_model } . "
539540 "Please set a valid path in the decoding config."
540541 )
541542
@@ -550,15 +551,15 @@ def flashlight_beam_search(
550551 from nemo .collections .asr .modules .flashlight_decoder import FlashLightKenLMBeamSearchDecoder
551552
552553 self .flashlight_beam_scorer = FlashLightKenLMBeamSearchDecoder (
553- lm_path = self .kenlm_path ,
554+ lm_path = self .ngram_lm_model ,
554555 vocabulary = self .vocab ,
555556 tokenizer = self .tokenizer ,
556557 lexicon_path = self .flashlight_cfg .lexicon_path ,
557558 boost_path = self .flashlight_cfg .boost_path ,
558559 beam_size = self .beam_size ,
559560 beam_size_token = self .flashlight_cfg .beam_size_token ,
560561 beam_threshold = self .flashlight_cfg .beam_threshold ,
561- lm_weight = self .beam_alpha ,
562+ lm_weight = self .ngram_lm_alpha ,
562563 word_score = self .beam_beta ,
563564 unk_weight = self .flashlight_cfg .unk_weight ,
564565 sil_weight = self .flashlight_cfg .sil_weight ,
@@ -877,6 +878,108 @@ def _k2_decoding(self, x: torch.Tensor, out_len: torch.Tensor) -> List['WfstNbes
877878 return self .k2_decoder .decode (x .to (device = self .device ), out_len .to (device = self .device ))
878879
879880
881+ class BeamBatchedCTCInfer (AbstractBeamCTCInfer ):
882+ """
883+ A batched beam CTC decoder.
884+
885+ Args:
886+ blank_index: int index of the blank token. Can be 0 or len(vocabulary).
887+ beam_size: int size of the beam.
888+ return_best_hypothesis: When set to True (default), returns a single Hypothesis.
889+ When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis.
890+ preserve_alignments: Bool flag which preserves the history of logprobs generated during
891+ decoding (sample / batched). When set to true, the Hypothesis will contain
892+ the non-null value for `logprobs` in it. Here, `logprobs` is a torch.Tensors.
893+ compute_timestamps: A bool flag, which determines whether to compute the character/subword, or
894+ word based timestamp mapping the output log-probabilities to discrite intervals of timestamps.
895+ The timestamps will be available in the returned Hypothesis.timestep as a dictionary.
896+ ngram_lm_alpha: float, the language model weight.
897+ beam_beta: float, the word insertion weight.
898+ beam_threshold: float, the beam pruning threshold.
899+ ngram_lm_model: str, the path to the ngram model.
900+ allow_cuda_graphs: bool, whether to allow cuda graphs for the beam search algorithm.
901+ """
902+
903+ def __init__ (
904+ self ,
905+ blank_index : int ,
906+ beam_size : int ,
907+ return_best_hypothesis : bool = True ,
908+ preserve_alignments : bool = False ,
909+ compute_timestamps : bool = False ,
910+ ngram_lm_alpha : float = 1.0 ,
911+ beam_beta : float = 0.0 ,
912+ beam_threshold : float = 20.0 ,
913+ ngram_lm_model : str = None ,
914+ allow_cuda_graphs : bool = True ,
915+ ):
916+ super ().__init__ (blank_id = blank_index , beam_size = beam_size )
917+
918+ self .return_best_hypothesis = return_best_hypothesis
919+ self .preserve_alignments = preserve_alignments
920+ self .compute_timestamps = compute_timestamps
921+ self .allow_cuda_graphs = allow_cuda_graphs
922+
923+ if self .compute_timestamps :
924+ raise ValueError ("`Compute timestamps` is not supported for batched beam search." )
925+ if self .preserve_alignments :
926+ raise ValueError ("`Preserve alignments` is not supported for batched beam search." )
927+
928+ self .ngram_lm_alpha = ngram_lm_alpha
929+ self .beam_beta = beam_beta
930+ self .beam_threshold = beam_threshold
931+
932+ # Default beam search args
933+ self .ngram_lm_model = ngram_lm_model
934+
935+ self .search_algorithm = BatchedBeamCTCComputer (
936+ blank_index = blank_index ,
937+ beam_size = beam_size ,
938+ return_best_hypothesis = return_best_hypothesis ,
939+ preserve_alignments = preserve_alignments ,
940+ compute_timestamps = compute_timestamps ,
941+ ngram_lm_alpha = ngram_lm_alpha ,
942+ beam_beta = beam_beta ,
943+ beam_threshold = beam_threshold ,
944+ ngram_lm_model = ngram_lm_model ,
945+ allow_cuda_graphs = allow_cuda_graphs ,
946+ )
947+
948+ @typecheck ()
949+ def forward (
950+ self ,
951+ decoder_output : torch .Tensor ,
952+ decoder_lengths : torch .Tensor ,
953+ ) -> Tuple [List [Union [rnnt_utils .Hypothesis , rnnt_utils .NBestHypotheses ]]]:
954+ """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
955+ Output token is generated auto-repressively.
956+
957+ Args:
958+ decoder_output: A tensor of size (batch, timesteps, features).
959+ decoder_lengths: list of int representing the length of each sequence
960+ output sequence.
961+
962+ Returns:
963+ packed list containing batch number of sentences (Hypotheses).
964+ """
965+ with torch .no_grad (), torch .inference_mode ():
966+ if decoder_output .ndim != 3 :
967+ raise ValueError (
968+ f"`decoder_output` must be a tensor of shape [B, T, V] (log probs, float). "
969+ f"Provided shape = { decoder_output .shape } "
970+ )
971+
972+ batched_beam_hyps = self .search_algorithm (decoder_output , decoder_lengths )
973+
974+ batch_size = decoder_lengths .shape [0 ]
975+ if self .return_best_hypothesis :
976+ hyps = batched_beam_hyps .to_hyps_list (score_norm = False )[:batch_size ]
977+ else :
978+ hyps = batched_beam_hyps .to_nbest_hyps_list (score_norm = False )[:batch_size ]
979+
980+ return (hyps ,)
981+
982+
880983@dataclass
881984class PyCTCDecodeConfig :
882985 # These arguments cannot be imported from pyctcdecode (optional dependency)
@@ -906,10 +1009,14 @@ class BeamCTCInferConfig:
9061009 preserve_alignments : bool = False
9071010 compute_timestamps : bool = False
9081011 return_best_hypothesis : bool = True
1012+ allow_cuda_graphs : bool = True
9091013
910- beam_alpha : float = 1.0
911- beam_beta : float = 0.0
912- kenlm_path : Optional [str ] = None
1014+ beam_alpha : Optional [float ] = None # Deprecated
1015+ beam_beta : float = 1.0
1016+ beam_threshold : float = 20.0
1017+ kenlm_path : Optional [str ] = None # Deprecated, default should be None
1018+ ngram_lm_alpha : Optional [float ] = 1.0
1019+ ngram_lm_model : Optional [str ] = None
9131020
9141021 flashlight_cfg : Optional [FlashlightConfig ] = field (default_factory = lambda : FlashlightConfig ())
9151022 pyctcdecode_cfg : Optional [PyCTCDecodeConfig ] = field (default_factory = lambda : PyCTCDecodeConfig ())
0 commit comments