 from omegaconf import DictConfig
 from torch.distributions import Categorical
 
-from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel
 from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier
 from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin
 from nemo.collections.common.parts import NEG_INF, mask_padded_tokens
@@ -507,9 +506,9 @@ def _forward(
         return tgt
 
 
-class BeamSearchSequenceGeneratorWithNGramLM(BeamSearchSequenceGenerator):
+class BeamSearchSequenceGeneratorWithFusionModels(BeamSearchSequenceGenerator):
     def __init__(
-        self, embedding, decoder, log_softmax, ngram_lm_model, ngram_lm_alpha=0.0, beam_size=1, len_pen=0, **kwargs
+        self, embedding, decoder, log_softmax, fusion_models, fusion_models_alpha, beam_size=1, len_pen=0, **kwargs
     ):
         """
         Beam Search sequence generator based on the decoder followed by
@@ -524,30 +523,43 @@ def __init__(
         """
 
         super().__init__(embedding, decoder, log_softmax, beam_size=beam_size, len_pen=len_pen, **kwargs)
-        # ngram lm
-        self.ngram_lm_batch = NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self.num_tokens)
-        self.ngram_lm_alpha = ngram_lm_alpha
+
+        self.fusion_models = fusion_models
+        self.fusion_models_alpha = fusion_models_alpha
 
     def _forward(
         self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
     ):
         device = encoder_hidden_states.device
-        # force ngram lm to use the same device as encoder_hidden_states, since current class is not nn.Module instance
-        self.ngram_lm_batch.to(device)
+        # force fusion models to use the same device as encoder_hidden_states, since current class is not nn.Module instance
+        for fusion_model in self.fusion_models:
+            fusion_model.to(device)
 
         tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)
-        batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True)
+
+        batch_fusion_states_list = [
+            fusion_model.get_init_states(batch_size=batch_size, bos=True) for fusion_model in self.fusion_models
+        ]
+        batch_fusion_states_candidates_list = []
 
         # generate initial buffer of beam_size prefixes-hypotheses
         log_probs, decoder_mems_list = self._one_step_forward(tgt, encoder_hidden_states, encoder_input_mask, None, 0)
-        # get ngram lm scores
-        lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states, eos_id=self.eos)
-        log_probs += self.ngram_lm_alpha * lm_scores[:, None, :]
+        # get fusion models scores
+        for fusion_model_idx, fusion_model in enumerate(self.fusion_models):
+            fusion_scores, batch_fusion_states_candidates = fusion_model.advance(
+                states=batch_fusion_states_list[fusion_model_idx], eos_id=self.eos
+            )
+            batch_fusion_states_candidates_list.append(batch_fusion_states_candidates)
+            log_probs += self.fusion_models_alpha[fusion_model_idx] * fusion_scores[:, None, :]
 
         scores, prefixes = torch.topk(log_probs.permute(0, 2, 1), self.beam_size, dim=1)  # [Batch, Beam, 1]
-        batch_lm_states = batch_lm_states_candidates.gather(dim=1, index=prefixes.squeeze(-1)).view(
-            -1
-        )  # [Batch, Beam] -> [Batch*Beam]
+        for fusion_model_idx, batch_fusion_states_candidates in enumerate(batch_fusion_states_candidates_list):
+            batch_fusion_states_list[fusion_model_idx] = batch_fusion_states_candidates.gather(
+                dim=1, index=prefixes.squeeze(-1)
+            ).view(
+                -1
+            )  # [Batch, Beam] -> [Batch*Beam]
+
         scores, prefixes = scores.view(-1, 1), prefixes.view(-1, 1)  # [Batch*Beam, 1]
 
         # repeat init target prefixes and cached memory states beam_size times
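Every entry of `fusion_models` is expected to provide the same three methods that the removed `ngram_lm_batch` attribute offered. Below is a minimal sketch of that assumed contract, inferred from the calls in this diff; the PR itself does not define such a protocol, and the name `FusionModelLike` is hypothetical.

```python
from typing import Protocol, Tuple

import torch


class FusionModelLike(Protocol):
    """Assumed interface for entries of `fusion_models` (inferred from usage)."""

    def to(self, device: torch.device):
        """Move internal buffers to the encoder's device."""
        ...

    def get_init_states(self, batch_size: int, bos: bool = True) -> torch.Tensor:
        """Return one initial state per batch element (BOS context if `bos`)."""
        ...

    def advance(self, states: torch.Tensor, eos_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (scores, state candidates), both shaped [B, vocab_size]."""
        ...
```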
@@ -583,13 +595,19 @@ def _forward(
             log_probs, decoder_mems_list = self._one_step_forward(
                 prefixes[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
             )
-            lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(
-                states=batch_lm_states, eos_id=self.eos
-            )
-            log_probs += self.ngram_lm_alpha * lm_scores[:, None, :]
+            for fusion_model_idx, fusion_model in enumerate(self.fusion_models):
+                fusion_scores, batch_fusion_states_candidates = fusion_model.advance(
+                    states=batch_fusion_states_list[fusion_model_idx], eos_id=self.eos
+                )
+                log_probs += self.fusion_models_alpha[fusion_model_idx] * fusion_scores[:, None, :]
+                batch_fusion_states_candidates_list[fusion_model_idx] = batch_fusion_states_candidates
 
             scores_i, prefixes_i = torch.topk(log_probs[:, -1, :], self.beam_size, dim=-1)  # [Batch*Beam, Beam]
-            batch_lm_states = batch_lm_states_candidates.gather(dim=1, index=prefixes_i)  # [Batch*Beam, Beam]
+
+            for fusion_model_idx, batch_fusion_states_candidates in enumerate(batch_fusion_states_candidates_list):
+                batch_fusion_states_list[fusion_model_idx] = batch_fusion_states_candidates.gather(
+                    dim=1, index=prefixes_i
+                )
 
             # for all prefixes ending with <eos> or <pad> replace generated
             # continuations with <pad>
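For reference, the `fusion_scores[:, None, :]` indexing in both loops inserts a singleton time dimension so the per-token fusion scores broadcast against the decoder's `log_probs`. A standalone shape check, with illustrative sizes and weight only:

```python
import torch

batch_beam, vocab, alpha = 6, 32, 0.5              # illustrative values
log_probs = torch.randn(batch_beam, 1, vocab)      # decoder scores for the last step
fusion_scores = torch.randn(batch_beam, vocab)     # one score per candidate next token
log_probs += alpha * fusion_scores[:, None, :]     # [B, V] -> [B, 1, V], broadcasts over time
assert log_probs.shape == (batch_beam, 1, vocab)
```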
@@ -605,9 +623,12 @@ def _forward(
             len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
             scores = scores / len_penalties
             scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1)  # [Batch, Beam]
-            batch_lm_states = (
-                batch_lm_states.view(-1, self.beam_size**2).gather(dim=1, index=indices_i).view(-1)
-            )  # [Batch, Beam] -> [Batch*Beam]
+
+            for fusion_model_idx, batch_fusion_states in enumerate(batch_fusion_states_list):
+                batch_fusion_states_list[fusion_model_idx] = (
+                    batch_fusion_states.view(-1, self.beam_size**2).gather(dim=1, index=indices_i).view(-1)
+                )
+
             scores = scores.view(-1, 1) * len_penalties  # [Batch*Beam, 1]
 
             # select prefixes which correspond to the chosen hypotheses
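The behavior of the removed `BeamSearchSequenceGeneratorWithNGramLM` is presumably recovered by wrapping a single n-gram LM in a list. A hypothetical construction sketch follows; the path, the fusion weight, and the `embedding`/`decoder`/`log_softmax` modules are placeholders, not values from the PR:

```python
from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel

# Hypothetical: load one n-gram LM and fuse its scores with weight 0.3.
ngram_lm = NGramGPULanguageModel.from_file(lm_path="lm.nemo", vocab_size=num_tokens)
generator = BeamSearchSequenceGeneratorWithFusionModels(
    embedding=embedding,        # placeholder decoder embedding module
    decoder=decoder,            # placeholder transformer decoder
    log_softmax=log_softmax,    # placeholder token classifier head
    fusion_models=[ngram_lm],
    fusion_models_alpha=[0.3],  # one weight per fusion model
    beam_size=4,
)
```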