Commit c366dbd

andrusenkoau, Copilot, chtruong814, and artbataev authored and committed
GPU-accelerated Phrase-Boosting (GPU-PB) for AED decoding (NVIDIA-NeMo#14108)
* add initial scripts
* add boosting tree construction
* add PB support to RNNT greedy decoding (Python implementation only)
* first step of the integration of PB for RNNT CUDA decoding
* some fixes
* revert changes for RNNT decoding
* add initial step of PB for the AED model
* fix a bug with fusion-model integration for AED beam decoding
* partial fix for the EOS score increasing after detection of a context phrase
* fix node score in the context graph
* fix backoff weight for the end node
* add dummy boosting tree
* add tests for the boosting tree; add more tests
* minor fixes; fix an unclosed file
* apply isort and black reformatting
* apply review suggestions to tests/collections/asr/test_boosting_tree.py, tests/collections/asr/decoding/test_multi_task_decoding.py, scripts/asr_context_biasing/build_gpu_boosting_tree.py, and scripts/asr_context_biasing/compute_key_words_fscore.py
* add ASR model path + name
* add boosting-tree config
* add loading of the boosting tree from config
* add a new test, test_boosting_tree_model_from_config
* PR fixes

Signed-off-by: andrusenkoau <[email protected]>
Signed-off-by: Andrei Andrusenko <[email protected]>
Co-authored-by: andrusenkoau <[email protected]>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Charlie Truong <[email protected]>
Co-authored-by: Vladimir Bataev <[email protected]>
1 parent c94b5f0 commit c366dbd
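For context, a minimal sketch of how the generalized generator is meant to be wired up. The class and module names, the fusion_models/fusion_models_alpha arguments, and the NGramGPULanguageModel.from_file signature all appear in the diffs below; the GPUBoostingTreeModel.from_file loader, the model components, and the alpha values are illustrative assumptions, not code from this commit.

# Hedged wiring sketch: fuse a GPU boosting tree and an n-gram LM into AED beam search.
from nemo.collections.asr.modules.transformer import BeamSearchSequenceGeneratorWithFusionModels
from nemo.collections.asr.parts.context_biasing import GPUBoostingTreeModel
from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel

# Both models expose the interface the generator calls:
# to(device), get_init_states(batch_size, bos=True), advance(states, eos_id).
boosting_tree = GPUBoostingTreeModel.from_file(...)  # assumed loader; tree built by scripts/asr_context_biasing/build_gpu_boosting_tree.py
ngram_lm = NGramGPULanguageModel.from_file(lm_path=..., vocab_size=...)  # signature taken from the code removed below

generator = BeamSearchSequenceGeneratorWithFusionModels(
    embedding=embedding,  # embedding/decoder/log_softmax: AED model components, assumed in scope
    decoder=decoder,
    log_softmax=log_softmax,
    fusion_models=[boosting_tree, ngram_lm],
    fusion_models_alpha=[1.0, 0.3],  # illustrative per-model fusion weights
    beam_size=4,
)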

File tree

12 files changed: +1473 additions, -51 deletions


nemo/collections/asr/modules/transformer/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -24,8 +24,8 @@
 from nemo.collections.asr.modules.transformer.transformer_encoders import TransformerEncoder
 from nemo.collections.asr.modules.transformer.transformer_generators import (
     BeamSearchSequenceGenerator,
+    BeamSearchSequenceGeneratorWithFusionModels,
     BeamSearchSequenceGeneratorWithLanguageModel,
-    BeamSearchSequenceGeneratorWithNGramLM,
     EnsembleBeamSearchSequenceGenerator,
     GreedySequenceGenerator,
     TopKSequenceGenerator,
@@ -44,7 +44,7 @@
     "TransformerEncoder",
     "BeamSearchSequenceGenerator",
     "BeamSearchSequenceGeneratorWithLanguageModel",
-    "BeamSearchSequenceGeneratorWithNGramLM",
+    "BeamSearchSequenceGeneratorWithFusionModels",
     "EnsembleBeamSearchSequenceGenerator",
     "GreedySequenceGenerator",
     "TopKSequenceGenerator",

nemo/collections/asr/modules/transformer/transformer_generators.py

Lines changed: 44 additions & 23 deletions

@@ -19,7 +19,6 @@
 from omegaconf import DictConfig
 from torch.distributions import Categorical
 
-from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel
 from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier
 from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin
 from nemo.collections.common.parts import NEG_INF, mask_padded_tokens
@@ -507,9 +506,9 @@ def _forward(
         return tgt
 
 
-class BeamSearchSequenceGeneratorWithNGramLM(BeamSearchSequenceGenerator):
+class BeamSearchSequenceGeneratorWithFusionModels(BeamSearchSequenceGenerator):
     def __init__(
-        self, embedding, decoder, log_softmax, ngram_lm_model, ngram_lm_alpha=0.0, beam_size=1, len_pen=0, **kwargs
+        self, embedding, decoder, log_softmax, fusion_models, fusion_models_alpha, beam_size=1, len_pen=0, **kwargs
     ):
         """
         Beam Search sequence generator based on the decoder followed by
@@ -524,30 +523,43 @@ def __init__(
         """
 
         super().__init__(embedding, decoder, log_softmax, beam_size=beam_size, len_pen=len_pen, **kwargs)
-        # ngram lm
-        self.ngram_lm_batch = NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self.num_tokens)
-        self.ngram_lm_alpha = ngram_lm_alpha
+
+        self.fusion_models = fusion_models
+        self.fusion_models_alpha = fusion_models_alpha
 
     def _forward(
         self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
     ):
         device = encoder_hidden_states.device
-        # force ngram lm to use the same device as encoder_hidden_states, since current class is not nn.Module instance
-        self.ngram_lm_batch.to(device)
+        # force fusion models to use the same device as encoder_hidden_states, since current class is not nn.Module instance
+        for fusion_model in self.fusion_models:
+            fusion_model.to(device)
 
         tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)
-        batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True)
+
+        batch_fusion_states_list = [
+            fusion_model.get_init_states(batch_size=batch_size, bos=True) for fusion_model in self.fusion_models
+        ]
+        batch_fusion_states_candidates_list = []
 
         # generate initial buffer of beam_size prefixes-hypotheses
         log_probs, decoder_mems_list = self._one_step_forward(tgt, encoder_hidden_states, encoder_input_mask, None, 0)
-        # get ngram lm scores
-        lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states, eos_id=self.eos)
-        log_probs += self.ngram_lm_alpha * lm_scores[:, None, :]
+        # get fusion models scores
+        for fusion_model_idx, fusion_model in enumerate(self.fusion_models):
+            fusion_scores, batch_fusion_states_candidates = fusion_model.advance(
+                states=batch_fusion_states_list[fusion_model_idx], eos_id=self.eos
+            )
+            batch_fusion_states_candidates_list.append(batch_fusion_states_candidates)
+            log_probs += self.fusion_models_alpha[fusion_model_idx] * fusion_scores[:, None, :]
 
         scores, prefixes = torch.topk(log_probs.permute(0, 2, 1), self.beam_size, dim=1)  # [Batch, Beam, 1]
-        batch_lm_states = batch_lm_states_candidates.gather(dim=1, index=prefixes.squeeze(-1)).view(
-            -1
-        )  # [Batch, Beam] -> [Batch*Beam]
+        for fusion_model_idx, batch_fusion_states_candidates in enumerate(batch_fusion_states_candidates_list):
+            batch_fusion_states_list[fusion_model_idx] = batch_fusion_states_candidates.gather(
+                dim=1, index=prefixes.squeeze(-1)
+            ).view(
+                -1
            )  # [Batch, Beam] -> [Batch*Beam]
+
         scores, prefixes = scores.view(-1, 1), prefixes.view(-1, 1)  # [Batch*Beam, 1]
 
         # repeat init target prefixes and cached memory states beam_size times
@@ -583,13 +595,19 @@ def _forward(
             log_probs, decoder_mems_list = self._one_step_forward(
                 prefixes[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
             )
-            lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(
-                states=batch_lm_states, eos_id=self.eos
-            )
-            log_probs += self.ngram_lm_alpha * lm_scores[:, None, :]
+            for fusion_model_idx, fusion_model in enumerate(self.fusion_models):
+                fusion_scores, batch_fusion_states_candidates = fusion_model.advance(
+                    states=batch_fusion_states_list[fusion_model_idx], eos_id=self.eos
+                )
+                log_probs += self.fusion_models_alpha[fusion_model_idx] * fusion_scores[:, None, :]
+                batch_fusion_states_candidates_list[fusion_model_idx] = batch_fusion_states_candidates
 
             scores_i, prefixes_i = torch.topk(log_probs[:, -1, :], self.beam_size, dim=-1)  # [Batch*Beam, Beam]
-            batch_lm_states = batch_lm_states_candidates.gather(dim=1, index=prefixes_i)  # [Batch*Beam, Beam]
+
+            for fusion_model_idx, batch_fusion_states_candidates in enumerate(batch_fusion_states_candidates_list):
+                batch_fusion_states_list[fusion_model_idx] = batch_fusion_states_candidates.gather(
+                    dim=1, index=prefixes_i
+                )
 
             # for all prefixes ending with <eos> or <pad> replace generated
             # continuations with <pad>
@@ -605,9 +623,12 @@ def _forward(
             len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
             scores = scores / len_penalties
             scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1)  # [Batch, Beam]
-            batch_lm_states = (
-                batch_lm_states.view(-1, self.beam_size**2).gather(dim=1, index=indices_i).view(-1)
-            )  # [Batch, Beam] -> [Batch*Beam]
+
+            for fusion_model_idx, batch_fusion_states in enumerate(batch_fusion_states_list):
+                batch_fusion_states_list[fusion_model_idx] = (
+                    batch_fusion_states.view(-1, self.beam_size**2).gather(dim=1, index=indices_i).view(-1)
+                )
+
             scores = scores.view(-1, 1) * len_penalties  # [Batch*Beam, 1]
 
             # select prefixes which correspond to the chosen hypotheses
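The generator only touches fusion models through three duck-typed calls visible in the diff above. A minimal sketch of that implied interface as a typing.Protocol follows; the Protocol itself is not part of this commit, and names and shapes are inferred from the call sites. NGramGPULanguageModel and the new GPUBoostingTreeModel are the two implementations this commit pairs with it, which is what lets one loop advance and re-gather per-model states.

from typing import Protocol, Tuple

import torch


class FusionModel(Protocol):
    """Interface implied by BeamSearchSequenceGeneratorWithFusionModels._forward (sketch, not part of this PR)."""

    def to(self, device: torch.device) -> None:
        """Move internal buffers to the encoder's device (the generator is not an nn.Module)."""
        ...

    def get_init_states(self, batch_size: int, bos: bool = True) -> torch.Tensor:
        """Return one start state per batch element; beam search later flattens them to [Batch*Beam]."""
        ...

    def advance(self, states: torch.Tensor, eos_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score every vocabulary token from the given states.

        Returns:
            scores: [Batch*Beam, Vocab] log-prob bonuses, scaled by fusion_models_alpha[i]
                and added to the decoder log-probs.
            next_states: [Batch*Beam, Vocab] candidate states, re-gathered with the topk indices.
        """
        ...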

nemo/collections/asr/parts/context_biasing/__init__.py

Lines changed: 13 additions & 0 deletions

@@ -12,9 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from nemo.collections.asr.parts.context_biasing.boosting_graph_batched import (
+    BoostingTreeModelConfig,
+    GPUBoostingTreeModel,
+)
 from nemo.collections.asr.parts.context_biasing.context_biasing_utils import (
     compute_fscore,
     merge_alignment_with_ws_hyps,
 )
 from nemo.collections.asr.parts.context_biasing.context_graph_ctc import ContextGraphCTC
 from nemo.collections.asr.parts.context_biasing.ctc_based_word_spotter import run_word_spotter
+
+__all__ = [
+    "GPUBoostingTreeModel",
+    "BoostingTreeModelConfig",
+    "compute_fscore",
+    "merge_alignment_with_ws_hyps",
+    "ContextGraphCTC",
+    "run_word_spotter",
+]
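With the new __all__, the boosting-tree entry points resolve from the package root. The names below are exactly those exported by this diff; construction details are not shown in this section, so none are assumed here.

# Package-level imports now exposed by __all__ (from this diff):
from nemo.collections.asr.parts.context_biasing import (
    BoostingTreeModelConfig,  # config for the boosting tree (the commit log adds loading from config)
    GPUBoostingTreeModel,     # GPU phrase-boosting tree, usable as a fusion model in AED beam search
)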