Merged
Changes from 91 commits
95 commits
8ead3f3
add ctc beam decoding
lilithgrigoryan Mar 24, 2025
8f8be47
add utils
lilithgrigoryan Mar 24, 2025
91e48c2
first working
lilithgrigoryan Mar 24, 2025
df1ac04
working cuda graphs
lilithgrigoryan Mar 27, 2025
f2ee3dc
fix bugs with cudagraohs
lilithgrigoryan Mar 31, 2025
83ecc97
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
lilithgrigoryan Mar 31, 2025
f50390e
working
lilithgrigoryan Mar 31, 2025
70cc5ca
small fix
lilithgrigoryan Apr 1, 2025
636c219
minor fix
lilithgrigoryan Apr 16, 2025
1733050
add logging
lilithgrigoryan Apr 25, 2025
9a4248f
add print
lilithgrigoryan Apr 27, 2025
9103945
to log sum exp
lilithgrigoryan Apr 28, 2025
9094e56
back to max score
lilithgrigoryan Apr 28, 2025
5b6bdac
fix bug in cudagraphs, save before refactor
lilithgrigoryan Apr 29, 2025
5339d82
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
lilithgrigoryan Apr 29, 2025
f944930
rm log10
lilithgrigoryan Apr 29, 2025
939d105
rm prints
lilithgrigoryan Apr 29, 2025
1de07f2
add reallocation
lilithgrigoryan Apr 29, 2025
af1277f
rm logprobs from state
lilithgrigoryan Apr 29, 2025
1b5a223
rm nexts from state
lilithgrigoryan Apr 29, 2025
6ded71a
rm prev lm states
lilithgrigoryan Apr 29, 2025
44c1ca6
small clean up
lilithgrigoryan Apr 29, 2025
963a257
clean up cuda graphs
lilithgrigoryan Apr 29, 2025
9df22a1
cudagraph working
lilithgrigoryan Apr 29, 2025
676913b
clean up torch working
lilithgrigoryan Apr 29, 2025
7071328
Apply isort and black reformatting
lilithgrigoryan Apr 29, 2025
1183403
rm files
lilithgrigoryan Apr 29, 2025
781c9d2
Merge branch 'lgrigoryan/ctc_beam_search_pr' of https://github.com/NV…
lilithgrigoryan Apr 29, 2025
d72f6f3
save
lilithgrigoryan Apr 30, 2025
457806d
add flatten
lilithgrigoryan Apr 30, 2025
83acec6
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
lilithgrigoryan Apr 30, 2025
d047c1f
Apply isort and black reformatting
lilithgrigoryan Apr 30, 2025
4a87def
clean up
lilithgrigoryan Apr 30, 2025
b01a08b
Merge branch 'lgrigoryan/ctc_beam_search_pr' of https://github.com/NV…
lilithgrigoryan Apr 30, 2025
b0bccdf
clean up
lilithgrigoryan Apr 30, 2025
181e7e0
add timestamps
lilithgrigoryan Apr 30, 2025
141430f
rm file
lilithgrigoryan Apr 30, 2025
cb4eef9
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
lilithgrigoryan Apr 30, 2025
f370560
Apply isort and black reformatting
lilithgrigoryan Apr 30, 2025
ebeb866
rename file
lilithgrigoryan Apr 30, 2025
8916e48
Merge branch 'lgrigoryan/ctc_beam_search_pr' of https://github.com/NV…
lilithgrigoryan Apr 30, 2025
d281a3f
add batched beam tests
lilithgrigoryan Apr 30, 2025
556850f
Apply isort and black reformatting
lilithgrigoryan Apr 30, 2025
f007390
add tests
lilithgrigoryan Apr 30, 2025
057c884
Merge branch 'lgrigoryan/ctc_beam_search_pr' of https://github.com/NV…
lilithgrigoryan Apr 30, 2025
9849c1f
Apply isort and black reformatting
lilithgrigoryan Apr 30, 2025
2bd7344
changed return type
lilithgrigoryan Apr 30, 2025
2bba9ab
Merge branch 'lgrigoryan/ctc_beam_search_pr' of https://github.com/NV…
lilithgrigoryan Apr 30, 2025
5d89c3a
clean up
lilithgrigoryan Apr 30, 2025
7dcec45
Apply isort and black reformatting
lilithgrigoryan Apr 30, 2025
d429ada
Merge branch 'lgrigoryan/ctc_beam_search_pr' of https://github.com/NV…
lilithgrigoryan Apr 30, 2025
eb7c350
minor changes
lilithgrigoryan Apr 30, 2025
6666816
clean up
lilithgrigoryan May 1, 2025
d3bb22d
clean up
lilithgrigoryan May 1, 2025
da40efb
Merge branch 'main' into lgrigoryan/ctc_beam_search_pr
ko3n1g May 2, 2025
f5af89f
renamed variables
lilithgrigoryan May 13, 2025
cd978f2
changed is_tdt to model_type
lilithgrigoryan May 13, 2025
c890e9c
unified batched beam hyps
lilithgrigoryan May 13, 2025
59d295a
Merge branch 'main' of ssh://gitlab-master.nvidia.com:12051/vbataev/n…
lilithgrigoryan May 13, 2025
3ff6a0d
Merge branch 'lgrigoryan/ctc_beam_search_pr' of https://github.com/NV…
lilithgrigoryan May 13, 2025
8669582
Apply isort and black reformatting
lilithgrigoryan May 13, 2025
31a82a0
clean up
lilithgrigoryan May 13, 2025
b32f63f
Merge branch 'lgrigoryan/ctc_beam_search_pr' of https://github.com/NV…
lilithgrigoryan May 13, 2025
1c62c5a
clean up
lilithgrigoryan May 13, 2025
971fc16
clean up
lilithgrigoryan May 13, 2025
8c7a306
clean up
lilithgrigoryan May 13, 2025
dd155f2
clean up
lilithgrigoryan May 13, 2025
f89fd03
clean up
lilithgrigoryan May 13, 2025
0615c25
clean up
lilithgrigoryan May 14, 2025
b5cfa6d
Update cuda_program_name
lilithgrigoryan May 14, 2025
e22541f
clean up and and commments
lilithgrigoryan May 14, 2025
5134e8f
clean up and small fixes
lilithgrigoryan May 14, 2025
ab9a4c3
Merge branch 'lgrigoryan/ctc_beam_search_pr' of https://github.com/NV…
lilithgrigoryan May 14, 2025
20b705c
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
lilithgrigoryan May 14, 2025
e198787
Apply isort and black reformatting
lilithgrigoryan May 14, 2025
f6258ae
fix
lilithgrigoryan May 14, 2025
b6997c9
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
lilithgrigoryan May 14, 2025
86e89f1
Merge branch 'lgrigoryan/ctc_beam_search_pr' of https://github.com/NV…
lilithgrigoryan May 14, 2025
a56a7e7
fix tests
lilithgrigoryan May 19, 2025
de8eba7
added check on model type
lilithgrigoryan May 19, 2025
70b6b82
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
lilithgrigoryan May 19, 2025
6de0b25
minor change
lilithgrigoryan May 29, 2025
f4b3983
rm repetitions LM scoring
lilithgrigoryan May 29, 2025
ed59213
add enum model type
lilithgrigoryan May 29, 2025
66b115a
add enum model type
lilithgrigoryan May 29, 2025
80a409e
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
lilithgrigoryan May 29, 2025
a55d688
Apply isort and black reformatting
lilithgrigoryan May 29, 2025
943fb7e
fix lm repetitions for cudahraphs
lilithgrigoryan Jun 3, 2025
0642c80
Merge branch 'lgrigoryan/ctc_beam_search_pr' of https://github.com/NV…
lilithgrigoryan Jun 3, 2025
763a03e
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
lilithgrigoryan Jun 3, 2025
cf95ecb
Apply isort and black reformatting
lilithgrigoryan Jun 3, 2025
f7f7c68
clean up
lilithgrigoryan Jun 9, 2025
7e2eb60
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
lilithgrigoryan Jun 9, 2025
335a90d
Merge branch 'lgrigoryan/ctc_beam_search_pr' of https://github.com/NV…
lilithgrigoryan Jun 9, 2025
38e2321
clean up
lilithgrigoryan Jun 9, 2025
1 change: 1 addition & 0 deletions audio
Submodule audio added at 95c61b
696 changes: 696 additions & 0 deletions nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py

Large diffs are not rendered by default.
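
Since the diff for `ctc_batched_beam_decoding.py` is not rendered here, the following is only a generic illustration of CTC prefix beam search for a single utterance in plain Python. It is a sketch under stated assumptions, not the PR's batched `BatchedBeamCTCComputer`: the names `ctc_prefix_beam_search` and `logsumexp` are hypothetical, and there is no LM fusion, pruning threshold, or CUDA graph support.

```python
import math
from collections import defaultdict

NEG_INF = float("-inf")


def logsumexp(*xs):
    """Numerically stable log(sum(exp(x))) that tolerates -inf inputs."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def ctc_prefix_beam_search(log_probs, beam_size, blank=0):
    """Illustrative single-utterance CTC prefix beam search (no LM).

    log_probs: list of frames, each a list of per-token log-probabilities.
    Returns a list of (token_ids, log_score), best first.
    """
    # Each prefix keeps two scores: paths ending in blank / ending in non-blank.
    beams = {(): (0.0, NEG_INF)}
    for frame in log_probs:
        nxt = defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (p_b, p_nb) in beams.items():
            for tok, lp in enumerate(frame):
                if tok == blank:
                    # Blank keeps the prefix unchanged.
                    b, nb = nxt[prefix]
                    nxt[prefix] = (logsumexp(b, p_b + lp, p_nb + lp), nb)
                elif prefix and tok == prefix[-1]:
                    # Repeated token without a blank collapses into the same prefix.
                    b, nb = nxt[prefix]
                    nxt[prefix] = (b, logsumexp(nb, p_nb + lp))
                    # Repeated token after a blank extends the prefix.
                    ext = prefix + (tok,)
                    b, nb = nxt[ext]
                    nxt[ext] = (b, logsumexp(nb, p_b + lp))
                else:
                    # A different token always extends the prefix.
                    ext = prefix + (tok,)
                    b, nb = nxt[ext]
                    nxt[ext] = (b, logsumexp(nb, p_b + lp, p_nb + lp))
        # Keep only the `beam_size` best prefixes by total score.
        best = sorted(nxt.items(), key=lambda kv: logsumexp(*kv[1]), reverse=True)
        beams = dict(best[:beam_size])
    return [(list(p), logsumexp(*s)) for p, s in beams.items()]


# Two frames, vocabulary {0: blank, 1: "a"}; the greedy path is blank-blank,
# but the prefix "a" collects more total probability (0.24 + 0.24 + 0.16).
frames = [[math.log(0.6), math.log(0.4)]] * 2
hyps = ctc_prefix_beam_search(frames, beam_size=2)
print(hyps[0][0], round(math.exp(hyps[0][1]), 2))  # [1] 0.64
```

The toy input shows why beam search can differ from greedy decoding: the most likely single path is all blanks, while the most likely label sequence is `a`.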

149 changes: 132 additions & 17 deletions nemo/collections/asr/parts/submodules/ctc_beam_decoding.py
@@ -22,6 +22,7 @@
import torch

from nemo.collections.asr.parts.k2.classes import GraphIntersectDenseConfig
from nemo.collections.asr.parts.submodules.ctc_batched_beam_decoding import BatchedBeamCTCComputer
from nemo.collections.asr.parts.submodules.ngram_lm import DEFAULT_TOKEN_OFFSET
from nemo.collections.asr.parts.submodules.wfst_decoder import RivaDecoderConfig, WfstNbestHypothesis
from nemo.collections.asr.parts.utils import rnnt_utils
@@ -204,7 +205,7 @@ def __call__(self, *args, **kwargs):


class BeamCTCInfer(AbstractBeamCTCInfer):
"""A greedy CTC decoder.
"""A beam CTC decoder.

Provides a common abstraction for sample level and batch level greedy decoding.

@@ -227,9 +228,9 @@ def __init__(
return_best_hypothesis: bool = True,
preserve_alignments: bool = False,
compute_timestamps: bool = False,
beam_alpha: float = 1.0,
ngram_lm_alpha: float = 0.3,
beam_beta: float = 0.0,
kenlm_path: str = None,
ngram_lm_model: str = None,
flashlight_cfg: Optional['FlashlightConfig'] = None,
pyctcdecode_cfg: Optional['PyCTCDecodeConfig'] = None,
):
@@ -260,11 +261,11 @@ def __init__(
# Log the beam search algorithm
logging.info(f"Beam search algorithm: {search_type}")

self.beam_alpha = beam_alpha
self.ngram_lm_alpha = ngram_lm_alpha
self.beam_beta = beam_beta

# Default beam search args
self.kenlm_path = kenlm_path
self.ngram_lm_model = ngram_lm_model

# PyCTCDecode params
if pyctcdecode_cfg is None:
@@ -349,9 +350,9 @@ def default_beam_search(

if self.default_beam_scorer is None:
# Check for filepath
if self.kenlm_path is None or not os.path.exists(self.kenlm_path):
if self.ngram_lm_model is None or not os.path.exists(self.ngram_lm_model):
raise FileNotFoundError(
f"KenLM binary file not found at : {self.kenlm_path}. "
f"KenLM binary file not found at : {self.ngram_lm_model}. "
f"Please set a valid path in the decoding config."
)

@@ -367,9 +368,9 @@ def default_beam_search(

self.default_beam_scorer = BeamSearchDecoderWithLM(
vocab=vocab,
lm_path=self.kenlm_path,
lm_path=self.ngram_lm_model,
beam_width=self.beam_size,
alpha=self.beam_alpha,
alpha=self.ngram_lm_alpha,
beta=self.beam_beta,
num_cpus=max(1, os.cpu_count()),
input_tensor=False,
@@ -451,7 +452,7 @@ def _pyctcdecode_beam_search(

if self.pyctcdecode_beam_scorer is None:
self.pyctcdecode_beam_scorer = pyctcdecode.build_ctcdecoder(
labels=self.vocab, kenlm_model_path=self.kenlm_path, alpha=self.beam_alpha, beta=self.beam_beta
labels=self.vocab, kenlm_model_path=self.ngram_lm_model, alpha=self.ngram_lm_alpha, beta=self.beam_beta
) # type: pyctcdecode.BeamSearchDecoderCTC

x = x.to('cpu').numpy()
@@ -533,9 +534,9 @@ def flashlight_beam_search(

if self.flashlight_beam_scorer is None:
# Check for filepath
if self.kenlm_path is None or not os.path.exists(self.kenlm_path):
if self.ngram_lm_model is None or not os.path.exists(self.ngram_lm_model):
raise FileNotFoundError(
f"KenLM binary file not found at : {self.kenlm_path}. "
f"KenLM binary file not found at : {self.ngram_lm_model}. "
"Please set a valid path in the decoding config."
)

@@ -550,15 +551,15 @@ def flashlight_beam_search(
from nemo.collections.asr.modules.flashlight_decoder import FlashLightKenLMBeamSearchDecoder

self.flashlight_beam_scorer = FlashLightKenLMBeamSearchDecoder(
lm_path=self.kenlm_path,
lm_path=self.ngram_lm_model,
vocabulary=self.vocab,
tokenizer=self.tokenizer,
lexicon_path=self.flashlight_cfg.lexicon_path,
boost_path=self.flashlight_cfg.boost_path,
beam_size=self.beam_size,
beam_size_token=self.flashlight_cfg.beam_size_token,
beam_threshold=self.flashlight_cfg.beam_threshold,
lm_weight=self.beam_alpha,
lm_weight=self.ngram_lm_alpha,
word_score=self.beam_beta,
unk_weight=self.flashlight_cfg.unk_weight,
sil_weight=self.flashlight_cfg.sil_weight,
@@ -877,6 +878,116 @@ def _k2_decoding(self, x: torch.Tensor, out_len: torch.Tensor) -> List['WfstNbes
return self.k2_decoder.decode(x.to(device=self.device), out_len.to(device=self.device))


class BeamBatchedCTCInfer(AbstractBeamCTCInfer):
"""
A batched beam CTC decoder.

Args:
blank_index: int index of the blank token. Can be 0 or len(vocabulary).
beam_size: int size of the beam.
return_best_hypothesis: When set to True (default), returns a single Hypothesis.
When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis.
preserve_alignments: Bool flag which preserves the history of logprobs generated during
decoding (sample / batched). When set to True, the Hypothesis will contain
a non-null value for `logprobs`. Here, `logprobs` is a torch.Tensor.
Not supported by batched beam search: passing True raises a ValueError.
compute_timestamps: A bool flag which determines whether to compute the character/subword or
word-based timestamps, mapping the output log-probabilities to discrete intervals of timestamps.
The timestamps will be available in the returned Hypothesis.timestep as a dictionary.
Not supported by batched beam search: passing True raises a ValueError.
ngram_lm_alpha: float, the language model weight.
beam_beta: float, the word insertion weight.
beam_threshold: float, the beam pruning threshold.
ngram_lm_model: str, the path to the ngram model.
allow_cuda_graphs: bool, whether to allow cuda graphs for the beam search algorithm.
"""

def __init__(
self,
blank_index: int,
beam_size: int,
return_best_hypothesis: bool = True,
preserve_alignments: bool = False,
compute_timestamps: bool = False,
ngram_lm_alpha: float = 1.0,
beam_beta: float = 0.0,
beam_threshold: float = 20.0,
ngram_lm_model: str = None,
allow_cuda_graphs: bool = True,
):
super().__init__(blank_id=blank_index, beam_size=beam_size)

self.return_best_hypothesis = return_best_hypothesis
self.preserve_alignments = preserve_alignments
self.compute_timestamps = compute_timestamps
self.allow_cuda_graphs = allow_cuda_graphs

if self.compute_timestamps:
raise ValueError("`Compute timestamps` is not supported for batched beam search.")
if self.preserve_alignments:
raise ValueError("`Preserve alignments` is not supported for batched beam search.")

self.vocab = None # This must be set by specific method by user before calling forward() !

self.ngram_lm_alpha = ngram_lm_alpha
self.beam_beta = beam_beta
self.beam_threshold = beam_threshold

# Default beam search args
self.ngram_lm_model = ngram_lm_model

# Default beam search scorer functions
self.default_beam_scorer = None
self.pyctcdecode_beam_scorer = None
self.flashlight_beam_scorer = None
self.token_offset = 0

self.search_algorithm = BatchedBeamCTCComputer(
blank_index=blank_index,
beam_size=beam_size,
return_best_hypothesis=return_best_hypothesis,
preserve_alignments=preserve_alignments,
compute_timestamps=compute_timestamps,
ngram_lm_alpha=ngram_lm_alpha,
beam_beta=beam_beta,
beam_threshold=beam_threshold,
ngram_lm_model=ngram_lm_model,
allow_cuda_graphs=allow_cuda_graphs,
)

@typecheck()
def forward(
self,
decoder_output: torch.Tensor,
decoder_lengths: torch.Tensor,
) -> Tuple[List[Union[rnnt_utils.Hypothesis, rnnt_utils.NBestHypotheses]]]:
"""Returns a list of hypotheses given an input batch of decoder outputs (log-probabilities).
Output tokens are generated auto-regressively.

Args:
decoder_output: A tensor of log-probabilities of shape (batch, timesteps, vocabulary), i.e. [B, T, V].
decoder_lengths: A tensor of ints giving the length of each output sequence in the batch.

Returns:
A packed list containing one entry per batch element: a Hypothesis, or
an NBestHypotheses when `return_best_hypothesis` is False.
"""
with torch.no_grad(), torch.inference_mode():
if decoder_output.ndim != 3:
raise ValueError(
f"`decoder_output` must be a tensor of shape [B, T, V] (log probs, float). "
f"Provided shape = {decoder_output.shape}"
)

batched_beam_hyps = self.search_algorithm(decoder_output, decoder_lengths)

batch_size = decoder_lengths.shape[0]
if self.return_best_hypothesis:
hyps = batched_beam_hyps.to_hyps_list(score_norm=False)[:batch_size]
else:
hyps = batched_beam_hyps.to_nbest_hyps_list(score_norm=False)[:batch_size]

return (hyps,)


@dataclass
class PyCTCDecodeConfig:
# These arguments cannot be imported from pyctcdecode (optional dependency)
@@ -906,10 +1017,14 @@ class BeamCTCInferConfig:
preserve_alignments: bool = False
compute_timestamps: bool = False
return_best_hypothesis: bool = True
allow_cuda_graphs: bool = True

beam_alpha: float = 1.0
beam_beta: float = 0.0
kenlm_path: Optional[str] = None
beam_alpha: Optional[float] = None # Deprecated
beam_beta: float = 1.0
beam_threshold: float = 20.0
kenlm_path: Optional[str] = None # Deprecated, default should be None
ngram_lm_alpha: Optional[float] = 1.0
ngram_lm_model: Optional[str] = None

flashlight_cfg: Optional[FlashlightConfig] = field(default_factory=lambda: FlashlightConfig())
pyctcdecode_cfg: Optional[PyCTCDecodeConfig] = field(default_factory=lambda: PyCTCDecodeConfig())
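
The `# Deprecated` fields in `BeamCTCInferConfig` default to `None` so that a value set by the user can be detected and forwarded to its replacement. A minimal sketch of that fallback pattern follows, assuming a stand-in dataclass `BeamCfgSketch` and a helper `resolve_deprecated` (both hypothetical names). It uses the standard-library `warnings` module, whereas the PR itself calls `logging.warning`.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class BeamCfgSketch:
    """Hypothetical stand-in for the LM-related fields of BeamCTCInferConfig."""

    beam_alpha: Optional[float] = None  # deprecated alias of ngram_lm_alpha
    kenlm_path: Optional[str] = None  # deprecated alias of ngram_lm_model
    ngram_lm_alpha: Optional[float] = 1.0
    ngram_lm_model: Optional[str] = None


def resolve_deprecated(cfg: BeamCfgSketch) -> BeamCfgSketch:
    """Copy deprecated fields into their replacements, warning once per field."""
    if cfg.beam_alpha is not None:
        warnings.warn(
            "`beam_alpha` is deprecated; use `ngram_lm_alpha` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        cfg.ngram_lm_alpha = cfg.beam_alpha
    if cfg.kenlm_path is not None:
        warnings.warn(
            "`kenlm_path` is deprecated; use `ngram_lm_model` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        cfg.ngram_lm_model = cfg.kenlm_path
    return cfg


# An old-style config still works, but the deprecated values win over the defaults.
legacy = resolve_deprecated(BeamCfgSketch(beam_alpha=0.5, kenlm_path="lm.bin"))
print(legacy.ngram_lm_alpha, legacy.ngram_lm_model)  # 0.5 lm.bin
```

Note that with this pattern a deprecated field silently overrides the new one when both are set, so callers migrating a config should remove the old key rather than keep both.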
68 changes: 49 additions & 19 deletions nemo/collections/asr/parts/submodules/ctc_decoding.py
@@ -178,15 +178,15 @@ class AbstractCTCDecoding(ConfidenceMixin):
optional bool, whether to return just the best hypothesis or all of the
hypotheses after beam search has concluded. This flag is set by default.

beam_alpha:
ngram_lm_alpha:
float, the strength of the Language model on the final score of a token.
final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length.
final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.

beam_beta:
float, the strength of the sequence length penalty on the final score of a token.
final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length.
final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.

kenlm_path:
ngram_lm_model:
str, path to a KenLM ARPA or .binary file (depending on the strategy chosen).
If the path is invalid (file is not found at path), will raise a deferred error at the moment
of calculation of beam search, so that users may update / change the decoding strategy
@@ -226,7 +226,7 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[
self.segment_seperators = self.cfg.get('segment_seperators', ['.', '?', '!'])
self.segment_gap_threshold = self.cfg.get('segment_gap_threshold', None)

possible_strategies = ['greedy', 'greedy_batch', 'beam', 'pyctcdecode', 'flashlight', 'wfst']
possible_strategies = ['greedy', 'greedy_batch', 'beam', 'pyctcdecode', 'flashlight', 'wfst', 'beam_batch']
if self.cfg.strategy not in possible_strategies:
raise ValueError(f"Decoding strategy must be one of {possible_strategies}. Given {self.cfg.strategy}")

@@ -267,6 +267,20 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[
if self.compute_timestamps is not None:
self.compute_timestamps |= self.preserve_frame_confidence

if self.cfg.strategy in ['flashlight', 'wfst', 'beam_batch', 'pyctcdecode', 'beam']:
if self.cfg.beam.beam_alpha is not None:
logging.warning(
"`beam_alpha` is deprecated and will be removed in a future release. "
"Please use `ngram_lm_alpha` instead."
)
self.cfg.beam.ngram_lm_alpha = self.cfg.beam.beam_alpha
if self.cfg.beam.kenlm_path is not None:
logging.warning(
"`kenlm_path` is deprecated and will be removed in a future release. "
"Please use `ngram_lm_model` instead."
)
self.cfg.beam.ngram_lm_model = self.cfg.beam.kenlm_path

if self.cfg.strategy == 'greedy':
self.decoding = ctc_greedy_decoding.GreedyCTCInfer(
blank_id=self.blank_id,
@@ -294,9 +308,9 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[
return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True),
preserve_alignments=self.preserve_alignments,
compute_timestamps=self.compute_timestamps,
beam_alpha=self.cfg.beam.get('beam_alpha', 1.0),
ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 1.0),
beam_beta=self.cfg.beam.get('beam_beta', 0.0),
kenlm_path=self.cfg.beam.get('kenlm_path', None),
ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None),
)

self.decoding.override_fold_consecutive_value = False
@@ -310,9 +324,9 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[
return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True),
preserve_alignments=self.preserve_alignments,
compute_timestamps=self.compute_timestamps,
beam_alpha=self.cfg.beam.get('beam_alpha', 1.0),
ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 1.0),
beam_beta=self.cfg.beam.get('beam_beta', 0.0),
kenlm_path=self.cfg.beam.get('kenlm_path', None),
ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None),
pyctcdecode_cfg=self.cfg.beam.get('pyctcdecode_cfg', None),
)

@@ -327,9 +341,9 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[
return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True),
preserve_alignments=self.preserve_alignments,
compute_timestamps=self.compute_timestamps,
beam_alpha=self.cfg.beam.get('beam_alpha', 1.0),
ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 1.0),
beam_beta=self.cfg.beam.get('beam_beta', 0.0),
kenlm_path=self.cfg.beam.get('kenlm_path', None),
ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None),
flashlight_cfg=self.cfg.beam.get('flashlight_cfg', None),
)

@@ -357,6 +371,22 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[

self.decoding.override_fold_consecutive_value = False

elif self.cfg.strategy == 'beam_batch':
self.decoding = ctc_beam_decoding.BeamBatchedCTCInfer(
blank_index=blank_id,
beam_size=self.cfg.beam.get('beam_size', 1),
return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True),
preserve_alignments=self.preserve_alignments,
compute_timestamps=self.compute_timestamps,
ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 1.0),
beam_beta=self.cfg.beam.get('beam_beta', 0.0),
beam_threshold=self.cfg.beam.get('beam_threshold', 20.0),
ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None),
allow_cuda_graphs=self.cfg.beam.get('allow_cuda_graphs', True),
)

self.decoding.override_fold_consecutive_value = False

else:
raise ValueError(
f"Incorrect decoding strategy supplied. Must be one of {possible_strategies}\n"
@@ -1051,15 +1081,15 @@ class CTCDecoding(AbstractCTCDecoding):
optional bool, whether to return just the best hypothesis or all of the
hypotheses after beam search has concluded. This flag is set by default.

beam_alpha:
ngram_lm_alpha:
float, the strength of the Language model on the final score of a token.
final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length.
final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.

beam_beta:
float, the strength of the sequence length penalty on the final score of a token.
final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length.
final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.

kenlm_path:
ngram_lm_model:
str, path to a KenLM ARPA or .binary file (depending on the strategy chosen).
If the path is invalid (file is not found at path), will raise a deferred error at the moment
of calculation of beam search, so that users may update / change the decoding strategy
@@ -1340,15 +1370,15 @@ class CTCBPEDecoding(AbstractCTCDecoding):
optional bool, whether to return just the best hypothesis or all of the
hypotheses after beam search has concluded. This flag is set by default.

beam_alpha:
ngram_lm_alpha:
float, the strength of the Language model on the final score of a token.
final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length.
final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.

beam_beta:
float, the strength of the sequence length penalty on the final score of a token.
final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length.
final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.

kenlm_path:
ngram_lm_model:
str, path to a KenLM ARPA or .binary file (depending on the strategy chosen).
If the path is invalid (file is not found at path), will raise a deferred error at the moment
of calculation of beam search, so that users may update / change the decoding strategy
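
The docstrings in this file describe the fused score as `final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length`. As a small worked example of that formula only (the helper name `final_score` and the input numbers are hypothetical, and how each decoding strategy computes the individual terms is not shown in this diff):

```python
def final_score(
    acoustic_score: float,
    lm_score: float,
    seq_length: int,
    ngram_lm_alpha: float = 1.0,
    beam_beta: float = 0.0,
) -> float:
    """Combine acoustic and LM log-scores with a length term, as in the docstring formula."""
    return acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length


# Acoustic log-score -4.0, LM log-score -6.0, hypothesis of 3 tokens:
# -4.0 + 0.5 * (-6.0) + 1.0 * 3 = -4.0
print(final_score(-4.0, -6.0, 3, ngram_lm_alpha=0.5, beam_beta=1.0))  # -4.0

# With the defaults (alpha=1.0, beta=0.0) the LM score is added unscaled
# and the length term vanishes: -4.0 + (-6.0) = -10.0
print(final_score(-4.0, -6.0, 3))  # -10.0
```

A larger `ngram_lm_alpha` makes the language model weigh more heavily against the acoustic evidence, while a positive `beam_beta` offsets the tendency of summed log-scores to favor shorter hypotheses.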