Skip to content
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
9b987c6
add initial scripts
andrusenkoau Jun 17, 2025
842a82e
add boosting tree construction
andrusenkoau Jun 17, 2025
0a89773
add pb support to rnnt greedy decoding for python impl only
andrusenkoau Jun 18, 2025
70a89e4
first step for the integration of PB for rnnt cuda decoding
andrusenkoau Jun 19, 2025
4a0d4ab
some fixes
andrusenkoau Jun 25, 2025
52358f0
Merge branch 'main' of github.com:andrusenkoau/NeMo into gpu_boosting…
andrusenkoau Jun 27, 2025
f19c2a9
revert changes for rnnt decoding
andrusenkoau Jun 27, 2025
307ed74
add initial step for pb for aed model
andrusenkoau Jun 27, 2025
47fd78f
fix a bug with fusion models integration for aed beam decoding
andrusenkoau Jun 30, 2025
f8739d3
partial fix for eos score increasing after detection of context phrase
andrusenkoau Jun 30, 2025
194fa5b
fix node score in the context graph
andrusenkoau Jul 1, 2025
47fc763
fix backoff weight for end node
andrusenkoau Jul 1, 2025
322c41a
add dummy boosting tree
andrusenkoau Jul 1, 2025
b671ec6
add tests for the boosting tree
andrusenkoau Jul 1, 2025
b2bfd43
add more tests
andrusenkoau Jul 1, 2025
ebf36d6
Merge branch 'main' of github.com:andrusenkoau/NeMo into gpu_boosting…
andrusenkoau Jul 2, 2025
8039e9c
minor fixes
andrusenkoau Jul 2, 2025
c90271e
Apply isort and black reformatting
andrusenkoau Jul 2, 2025
531c97e
minor fixes
andrusenkoau Jul 2, 2025
bba5248
Apply isort and black reformatting
andrusenkoau Jul 2, 2025
ad24612
minor fixes
andrusenkoau Jul 2, 2025
c003e21
fix not closed file
andrusenkoau Jul 2, 2025
13bac85
minor fix
andrusenkoau Jul 2, 2025
69b5188
Apply isort and black reformatting
andrusenkoau Jul 2, 2025
6ce38ec
Merge branch 'main' into gpu_boosting_aed_pr
andrusenkoau Jul 2, 2025
358a3a1
Merge branch 'main' into gpu_boosting_aed_pr
andrusenkoau Jul 3, 2025
5626fca
Update tests/collections/asr/test_boosting_tree.py
andrusenkoau Jul 3, 2025
88af77c
Update tests/collections/asr/decoding/test_multi_task_decoding.py
andrusenkoau Jul 3, 2025
d3d5cc6
Update scripts/asr_context_biasing/build_gpu_boosting_tree.py
andrusenkoau Jul 3, 2025
55d2230
Update scripts/asr_context_biasing/compute_key_words_fscore.py
andrusenkoau Jul 3, 2025
2bf34a7
Merge branch 'main' into gpu_boosting_aed_pr
andrusenkoau Jul 3, 2025
34252e6
Merge branch 'main' into gpu_boosting_aed_pr
andrusenkoau Jul 4, 2025
4d2bbe2
add asr model path + name
andrusenkoau Jul 4, 2025
81e82e8
add bt config
andrusenkoau Jul 4, 2025
e9463ba
add loading boosting tree from config
andrusenkoau Jul 4, 2025
bb0de6a
add a new test for test_boosting_tree_model_from_config
andrusenkoau Jul 4, 2025
afb7ad8
Merge branch 'main' into gpu_boosting_aed_pr
andrusenkoau Jul 4, 2025
2f26750
Apply isort and black reformatting
andrusenkoau Jul 4, 2025
e90a017
minor fix
andrusenkoau Jul 4, 2025
7d0ce2a
Merge branch 'main' into gpu_boosting_aed_pr
andrusenkoau Jul 7, 2025
f1c4ccc
Merge branch 'main' into gpu_boosting_aed_pr
andrusenkoau Jul 10, 2025
f832213
pr fixes
andrusenkoau Jul 10, 2025
84e4bf9
Apply isort and black reformatting
andrusenkoau Jul 10, 2025
dcac2f1
minor fixes
andrusenkoau Jul 10, 2025
472f0ef
minor fix
andrusenkoau Jul 10, 2025
9c1d455
Merge branch 'main' into gpu_boosting_aed_pr
chtruong814 Jul 11, 2025
0ddae81
Merge branch 'main' into gpu_boosting_aed_pr
andrusenkoau Jul 11, 2025
f2223e5
Merge branch 'main' into gpu_boosting_aed_pr
andrusenkoau Jul 11, 2025
355e00f
Update scripts/asr_context_biasing/build_gpu_boosting_tree.py
andrusenkoau Jul 11, 2025
1527622
Update scripts/asr_context_biasing/build_gpu_boosting_tree.py
andrusenkoau Jul 11, 2025
777f377
Apply isort and black reformatting
andrusenkoau Jul 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nemo/collections/asr/modules/transformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from nemo.collections.asr.modules.transformer.transformer_encoders import TransformerEncoder
from nemo.collections.asr.modules.transformer.transformer_generators import (
BeamSearchSequenceGenerator,
BeamSearchSequenceGeneratorWithFusionModels,
BeamSearchSequenceGeneratorWithLanguageModel,
BeamSearchSequenceGeneratorWithNGramLM,
EnsembleBeamSearchSequenceGenerator,
GreedySequenceGenerator,
TopKSequenceGenerator,
Expand All @@ -44,7 +44,7 @@
"TransformerEncoder",
"BeamSearchSequenceGenerator",
"BeamSearchSequenceGeneratorWithLanguageModel",
"BeamSearchSequenceGeneratorWithNGramLM",
"BeamSearchSequenceGeneratorWithFusionModels",
"EnsembleBeamSearchSequenceGenerator",
"GreedySequenceGenerator",
"TopKSequenceGenerator",
Expand Down
67 changes: 44 additions & 23 deletions nemo/collections/asr/modules/transformer/transformer_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from omegaconf import DictConfig
from torch.distributions import Categorical

from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel
from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin
from nemo.collections.common.parts import NEG_INF, mask_padded_tokens
Expand Down Expand Up @@ -507,9 +506,9 @@ def _forward(
return tgt


class BeamSearchSequenceGeneratorWithNGramLM(BeamSearchSequenceGenerator):
class BeamSearchSequenceGeneratorWithFusionModels(BeamSearchSequenceGenerator):
def __init__(
self, embedding, decoder, log_softmax, ngram_lm_model, ngram_lm_alpha=0.0, beam_size=1, len_pen=0, **kwargs
self, embedding, decoder, log_softmax, fusion_models, fusion_models_alpha, beam_size=1, len_pen=0, **kwargs
):
"""
Beam Search sequence generator based on the decoder followed by
Expand All @@ -524,30 +523,43 @@ def __init__(
"""

super().__init__(embedding, decoder, log_softmax, beam_size=beam_size, len_pen=len_pen, **kwargs)
# ngram lm
self.ngram_lm_batch = NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self.num_tokens)
self.ngram_lm_alpha = ngram_lm_alpha

self.fusion_models = fusion_models
self.fusion_models_alpha = fusion_models_alpha

def _forward(
self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
):
device = encoder_hidden_states.device
# force ngram lm to use the same device as encoder_hidden_states, since current class is not nn.Module instance
self.ngram_lm_batch.to(device)
# force fusion models to use the same device as encoder_hidden_states, since current class is not nn.Module instance
for fusion_model in self.fusion_models:
fusion_model.to(device)

tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)
batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True)

batch_fusion_states_list = [
fusion_model.get_init_states(batch_size=batch_size, bos=True) for fusion_model in self.fusion_models
]
batch_fusion_states_candidates_list = []

# generate initial buffer of beam_size prefixes-hypotheses
log_probs, decoder_mems_list = self._one_step_forward(tgt, encoder_hidden_states, encoder_input_mask, None, 0)
# get ngram lm scores
lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states, eos_id=self.eos)
log_probs += self.ngram_lm_alpha * lm_scores[:, None, :]
# get fusion models scores
for fusion_model_idx, fusion_model in enumerate(self.fusion_models):
fusion_scores, batch_fusion_states_candidates = fusion_model.advance(
states=batch_fusion_states_list[fusion_model_idx], eos_id=self.eos
)
batch_fusion_states_candidates_list.append(batch_fusion_states_candidates)
log_probs += self.fusion_models_alpha[fusion_model_idx] * fusion_scores[:, None, :]

scores, prefixes = torch.topk(log_probs.permute(0, 2, 1), self.beam_size, dim=1) # [Batch, Beam, 1]
batch_lm_states = batch_lm_states_candidates.gather(dim=1, index=prefixes.squeeze(-1)).view(
-1
) # [Batch, Beam] -> [Batch*Beam]
for fusion_model_idx, batch_fusion_states_candidates in enumerate(batch_fusion_states_candidates_list):
batch_fusion_states_list[fusion_model_idx] = batch_fusion_states_candidates.gather(
dim=1, index=prefixes.squeeze(-1)
).view(
-1
) # [Batch, Beam] -> [Batch*Beam]

scores, prefixes = scores.view(-1, 1), prefixes.view(-1, 1) # [Batch*Beam, 1]

# repeat init target prefixes and cached memory states beam_size times
Expand Down Expand Up @@ -583,13 +595,19 @@ def _forward(
log_probs, decoder_mems_list = self._one_step_forward(
prefixes[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
)
lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(
states=batch_lm_states, eos_id=self.eos
)
log_probs += self.ngram_lm_alpha * lm_scores[:, None, :]
for fusion_model_idx, fusion_model in enumerate(self.fusion_models):
fusion_scores, batch_fusion_states_candidates = fusion_model.advance(
states=batch_fusion_states_list[fusion_model_idx], eos_id=self.eos
)
log_probs += self.fusion_models_alpha[fusion_model_idx] * fusion_scores[:, None, :]
batch_fusion_states_candidates_list[fusion_model_idx] = batch_fusion_states_candidates

scores_i, prefixes_i = torch.topk(log_probs[:, -1, :], self.beam_size, dim=-1) # [Batch*Beam, Beam]
batch_lm_states = batch_lm_states_candidates.gather(dim=1, index=prefixes_i) # [Batch*Beam, Beam]

for fusion_model_idx, batch_fusion_states_candidates in enumerate(batch_fusion_states_candidates_list):
batch_fusion_states_list[fusion_model_idx] = batch_fusion_states_candidates.gather(
dim=1, index=prefixes_i
)

# for all prefixes ending with <eos> or <pad> replace generated
# continuations with <pad>
Expand All @@ -605,9 +623,12 @@ def _forward(
len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
scores = scores / len_penalties
scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) # [Batch, Beam]
batch_lm_states = (
batch_lm_states.view(-1, self.beam_size**2).gather(dim=1, index=indices_i).view(-1)
) # [Batch, Beam] -> [Batch*Beam]

for fusion_model_idx, batch_fusion_states in enumerate(batch_fusion_states_list):
batch_fusion_states_list[fusion_model_idx] = (
batch_fusion_states.view(-1, self.beam_size**2).gather(dim=1, index=indices_i).view(-1)
)

scores = scores.view(-1, 1) * len_penalties # [Batch*Beam, 1]

# select prefixes which correspond to the chosen hypotheses
Expand Down
9 changes: 9 additions & 0 deletions nemo/collections/asr/parts/context_biasing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.asr.parts.context_biasing.boosting_graph_batched import GPUBoostingTreeModel
from nemo.collections.asr.parts.context_biasing.context_biasing_utils import (
compute_fscore,
merge_alignment_with_ws_hyps,
)
from nemo.collections.asr.parts.context_biasing.context_graph_ctc import ContextGraphCTC
from nemo.collections.asr.parts.context_biasing.ctc_based_word_spotter import run_word_spotter

__all__ = [
"GPUBoostingTreeModel",
"compute_fscore",
"merge_alignment_with_ws_hyps",
"ContextGraphCTC",
"run_word_spotter",
]
Loading
Loading