Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
57fabaa
added use-fast tokenizer argument (#12986)
f14-bertolotti May 8, 2025
d15485e
ci: Run selective triggering on dockerfiles and dependencies (#13493)
ko3n1g May 8, 2025
465ab59
fix buffered inference for tdt
May 8, 2025
f9684be
small fixes
May 8, 2025
c2fedfb
[automodel] fallback FP8 + LCE -> FP8 + CE (#13349)
akoumpa May 8, 2025
b74a045
Update changelog for `r2.3.0` (#13501)
github-actions[bot] May 8, 2025
7cbf683
Update 2.3.0 changelog (#13503)
chtruong814 May 8, 2025
c2b0d46
ci: Remove trt-llm breakpoint (#13499)
ko3n1g May 8, 2025
a2aabc8
Update 2.3.0 changelog (#13504)
chtruong814 May 8, 2025
896f4b6
Enabling flash decode for float16 precision only (#13471)
pthombre May 8, 2025
93a2483
Fix changelog formatting (#13505)
chtruong814 May 8, 2025
a2c4cc8
Updating the long context performance number for B200 (#13468)
youngeunkwon0405 May 8, 2025
2d11f09
Autodetect model_type and dtype for deployment using TRT-LLM backend …
janekl May 9, 2025
1b43d64
remove unused variable
May 9, 2025
1da5fee
Apply isort and black reformatting
hainan-xv May 8, 2025
3be53e4
add doc string, cleaner way of setting mergo_algo
May 9, 2025
d170dc8
Apply isort and black reformatting
hainan-xv May 9, 2025
64d7603
add extra hyena tests (#13097)
JRD971000 May 9, 2025
62284b6
ci: Add mode files to filter (#13517)
ko3n1g May 9, 2025
ebe7743
Merge branch 'main' of https://github.com/NVIDIA/NeMo into tdt_buffer…
May 10, 2025
47ce9a2
change default merge_algo for buffered inference to None
May 12, 2025
d46ae4c
Merge branch 'main' of https://github.com/NVIDIA/NeMo into tdt_buffer…
May 12, 2025
a32fd25
Apply isort and black reformatting
hainan-xv May 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
Buffered inference will use large chunk sizes (5-10 seconds) + some additional buffer for context.
Streaming inference will use small chunk sizes (0.1 to 0.25 seconds) + some additional buffer for context.

Note: greedy_batch inference is currently not supported for TDT models. The decoding strategy is automatically
set to greedy for TDT.

# Middle Token merge algorithm

python speech_to_text_buffered_infer_rnnt.py \
Expand Down Expand Up @@ -73,6 +76,7 @@
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.streaming_utils import (
BatchedFrameASRRNNT,
BatchedFrameASRTDT,
LongestCommonSubsequenceBatchedFrameASRRNNT,
)
from nemo.collections.asr.parts.utils.transcribe_utils import (
Expand Down Expand Up @@ -135,7 +139,10 @@ class TranscriptionConfig:
stateful_decoding: bool = False # Whether to perform stateful decoding

# Merge algorithm for transducers
merge_algo: Optional[str] = 'middle' # choices=['middle', 'lcs'], choice of algorithm to apply during inference.
# choices=['middle', 'lcs', 'tdt'], choice of algorithm to apply during inference.
# if None, we use 'middle' for rnnt and 'tdt' for tdt.
merge_algo: Optional[str] = None

lcs_alignment_dir: Optional[str] = None # Path to a directory to store LCS algo alignments

# Config for word / character error rate calculation
Expand All @@ -150,6 +157,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
"""
Transcribes the input audio and can be used to infer long audio files by chunking
them into smaller segments.
Currently, greedy_batch inference is not supported for TDT models. The decoding strategy
is automatically set to greedy for TDT.
"""
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -212,9 +221,17 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
asr_model.freeze()
asr_model = asr_model.to(asr_model.device)

model_is_tdt = hasattr(asr_model.loss, '_loss') and type(asr_model.loss._loss).__name__ == 'TDTLossNumba'
if cfg.merge_algo is None:
cfg.merge_algo = "tdt" if model_is_tdt else "middle"
logging.info(f"merge_algo not specified. We use the default algorithm (middle for rnnt and tdt for tdt).")

if model_is_tdt and cfg.merge_algo != "tdt":
raise ValueError("merge_algo must be 'tdt' for TDT models")

# Change Decoding Config
with open_dict(cfg.decoding):
if cfg.stateful_decoding:
if cfg.stateful_decoding or cfg.merge_algo == 'tdt':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add this info to the doc string above and also add to script usage at the top of this script.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

cfg.decoding.strategy = "greedy"
else:
cfg.decoding.strategy = "greedy_batch"
Expand Down Expand Up @@ -267,6 +284,16 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
# Set the LCS algorithm delay.
frame_asr.lcs_delay = math.floor(((total_buffer - chunk_len)) / model_stride_in_secs)

elif cfg.merge_algo == 'tdt':
frame_asr = BatchedFrameASRTDT(
asr_model=asr_model,
frame_len=chunk_len,
total_buffer=cfg.total_buffer_in_secs,
batch_size=cfg.batch_size,
max_steps_per_timestep=cfg.max_steps_per_timestep,
stateful_decoding=cfg.stateful_decoding,
)

else:
raise ValueError("Invalid choice of merge algorithm for transducer buffered inference.")

Expand Down
6 changes: 4 additions & 2 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2673,10 +2673,12 @@ def _greedy_decode(

if self.preserve_alignments:
# convert Ti-th logits into a torch array
hypothesis.alignments.append([]) # blank buffer for next timestep
for i in range(skip):
hypothesis.alignments.append([]) # blank buffer until next timestep

if self.preserve_frame_confidence:
hypothesis.frame_confidence.append([]) # blank buffer for next timestep
for i in range(skip):
hypothesis.frame_confidence.append([]) # blank buffer for next timestep

if symbols_added == self.max_symbols:
time_idx += 1
Expand Down
102 changes: 102 additions & 0 deletions nemo/collections/asr/parts/utils/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,6 +1264,108 @@ def greedy_merge(self, preds):
return hypothesis


class BatchedFrameASRTDT(BatchedFrameASRRNNT):
    """
    Batched implementation of FrameBatchASR for TDT models, where the batch dimension is independent audio samples.
    It's mostly similar to BatchedFrameASRRNNT with special handling of boundary cases due to the frame-skipping
    resulted by TDT models.
    """

    def __init__(
        self,
        asr_model,
        frame_len=1.6,
        total_buffer=4.0,
        batch_size=32,
        max_steps_per_timestep: int = 5,
        stateful_decoding: bool = False,
        tdt_search_boundary: int = 4,
    ):
        """
        Args:
            asr_model: A TDT model.
            frame_len: frame's duration, seconds.
            total_buffer: duration of total audio chunk size, in seconds.
            batch_size: Number of independent audio samples to process at each step.
            max_steps_per_timestep: Maximum number of tokens (u) to process per acoustic timestep (t).
            stateful_decoding: Boolean whether to enable stateful decoding for preservation of state across buffers.
            tdt_search_boundary: The max number of frames that we search between chunks to match the token at boundary.
        """
        # Forward every decoding option to the parent. Previously `max_steps_per_timestep` and
        # `stateful_decoding` were accepted here but never passed on, so caller-supplied values were ignored.
        # NOTE(review): assumes BatchedFrameASRRNNT.__init__ accepts both keywords - confirm against the parent.
        super().__init__(
            asr_model,
            frame_len=frame_len,
            total_buffer=total_buffer,
            batch_size=batch_size,
            max_steps_per_timestep=max_steps_per_timestep,
            stateful_decoding=stateful_decoding,
        )
        self.tdt_search_boundary = tdt_search_boundary

    def transcribe(
        self,
        tokens_per_chunk: int,
        delay: int,
    ):
        """
        Performs "middle token" alignment prediction using the buffered audio chunk.

        For every buffer, the chunk window of the alignment is decoded twice: once as-is and once extended
        backwards by up to `tdt_search_boundary` frames. From the second buffer on, the extended window is
        searched backwards for the last token that was already merged, and only the tokens after that match
        are merged, so that a token at the chunk boundary is not emitted twice.

        Args:
            tokens_per_chunk: Number of alignment frames that make up one chunk.
            delay: Number of frames counted back from the end of the buffer at which the chunk window starts.

        Returns:
            A list with one entry per batch element, each produced by `self.greedy_merge`.

        Raises:
            ValueError: If the signal end index of a sample is None.
        """
        self.infer_logits()

        self.unmerged = [[] for _ in range(self.batch_size)]
        for idx, alignments in enumerate(self.all_alignments):

            signal_end_idx = self.frame_bufferer.signal_end_index[idx]
            if signal_end_idx is None:
                raise ValueError("Signal did not end")

            for a_idx, alignment in enumerate(alignments):
                # chunk size = buffer size -> no offset; all other cases -> offset of one frame
                offset = 0 if delay == len(alignment) else 1

                chunk_start = len(alignment) - offset - delay
                chunk_end = chunk_start + tokens_per_chunk

                # Clamp the look-back start at the beginning of the buffer. Without the clamp, a chunk that
                # starts fewer than `tdt_search_boundary` frames into the buffer produces a negative slice
                # start, which Python interprets relative to the END of the list and selects the wrong frames.
                lookback_start = max(0, chunk_start - self.tdt_search_boundary)

                longer_alignment = alignment[lookback_start:chunk_end]
                alignment = alignment[chunk_start:chunk_end]

                longer_ids, longer_toks = self._alignment_decoder(
                    longer_alignment, self.asr_model.tokenizer, self.blank_id
                )
                ids, _ = self._alignment_decoder(alignment, self.asr_model.tokenizer, self.blank_id)

                if len(longer_ids) > 0 and a_idx < signal_end_idx:
                    if a_idx == 0 or len(self.unmerged[idx]) == 0:
                        # Nothing merged yet for this sample: take the chunk tokens as they are.
                        self.unmerged[idx] = inplace_buffer_merge(
                            self.unmerged[idx],
                            ids,
                            delay,
                            model=self.asr_model,
                        )
                    elif len(longer_toks) > 1:
                        # Reaching this branch implies that self.unmerged[idx] is non-empty.
                        id_to_match = self.unmerged[idx][-1]
                        # Search backwards, starting at the position of the first chunk token inside the
                        # extended window, for the most recently merged token.
                        start = min(len(longer_ids) - len(ids), len(longer_ids) - 1)
                        for i in range(start, -1, -1):
                            if longer_ids[i] == id_to_match:
                                # Keep only what follows the already-merged token.
                                ids = longer_ids[i + 1 :]
                                break

                        self.unmerged[idx] = inplace_buffer_merge(
                            self.unmerged[idx],
                            ids,
                            delay,
                            model=self.asr_model,
                        )

        output = []
        for idx in range(self.batch_size):
            output.append(self.greedy_merge(self.unmerged[idx]))
        return output


class LongestCommonSubsequenceBatchedFrameASRRNNT(BatchedFrameASRRNNT):
"""
Implements a token alignment algorithm for text alignment instead of middle token alignment.
Expand Down
Loading