
Commit ff43bc2

Hainan Xu authored and committed
Tdt buffered inference fix (NVIDIA-NeMo#13500)
* added use-fast tokenizer argument (NVIDIA-NeMo#12986) Signed-off-by: Francesco Bertolotti <[email protected]> Co-authored-by: Alexandros Koumparoulis <[email protected]> Co-authored-by: oliver könig <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* ci: Run selective triggering on dockerfiles and dependencies (NVIDIA-NeMo#13493) Signed-off-by: oliver könig <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* fix buffered inference for tdt Signed-off-by: Hainan Xu <[email protected]>
* small fixes Signed-off-by: Hainan Xu <[email protected]>
* [automodel] fallback FP8 + LCE -> FP8 + CE (NVIDIA-NeMo#13349)
  * fix Signed-off-by: Alexandros Koumparoulis <[email protected]>
  * make fp8 tests non-optional Signed-off-by: Alexandros Koumparoulis <[email protected]>
  * switch to gemma Signed-off-by: Alexandros Koumparoulis <[email protected]>
  --------- Signed-off-by: Alexandros Koumparoulis <[email protected]> Co-authored-by: oliver könig <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* Update changelog for `r2.3.0` (NVIDIA-NeMo#13501)
  * beep boop: Update changelog Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  * Add changelog highlights Signed-off-by: Charlie Truong <[email protected]>
  --------- Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Signed-off-by: Charlie Truong <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Charlie Truong <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* Update 2.3.0 changelog (NVIDIA-NeMo#13503) Signed-off-by: Charlie Truong <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* ci: Remove trt-llm breakpoint (NVIDIA-NeMo#13499)
  * tests: Disable flaky test Signed-off-by: oliver könig <[email protected]>
  * remove breakpoint Signed-off-by: oliver könig <[email protected]>
  --------- Signed-off-by: oliver könig <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* Update 2.3.0 changelog (NVIDIA-NeMo#13504)
  * Fix 2.3.0 changelog Signed-off-by: Charlie Truong <[email protected]>
  * Update 2.3.0 changelog Signed-off-by: Charlie Truong <[email protected]>
  --------- Signed-off-by: Charlie Truong <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* Enabling flash decode for float16 precision only (NVIDIA-NeMo#13471) Signed-off-by: Pranav Prashant Thombre <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* Fix changelog formatting (NVIDIA-NeMo#13505) Signed-off-by: Charlie Truong <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* Updating the long context performance number for B200 (NVIDIA-NeMo#13468)
  * Add without CP numbers for B200 and merge the captioning texts of both into one. Signed-off-by: Youngeun Kwon <[email protected]>
  * figure removed Signed-off-by: Youngeun Kwon <[email protected]>
  --------- Signed-off-by: Youngeun Kwon <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* Autodetect model_type and dtype for deployment using TRT-LLM backend (NVIDIA-NeMo#13209)
  * Autodetect model_type and dtype for deployment using TRT-LLM backend Signed-off-by: Jan Lasek <[email protected]>
  * Handling kv_cache_qformat parameter Signed-off-by: Jan Lasek <[email protected]>
  * Apply isort and black reformatting Signed-off-by: janekl <[email protected]>
  * Docstring update Signed-off-by: Jan Lasek <[email protected]>
  --------- Signed-off-by: Jan Lasek <[email protected]> Signed-off-by: janekl <[email protected]> Co-authored-by: janekl <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* remove unused variable Signed-off-by: Hainan Xu <[email protected]>
* Apply isort and black reformatting Signed-off-by: hainan-xv <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* add doc string, cleaner way of setting merge_algo Signed-off-by: Hainan Xu <[email protected]>
* Apply isort and black reformatting Signed-off-by: hainan-xv <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* add extra hyena tests (NVIDIA-NeMo#13097)
  * add extra hyena tests
  * Apply isort and black reformatting Signed-off-by: JRD971000 <[email protected]>
  * fix num gpus
  * keep sft optional
  --------- Signed-off-by: JRD971000 <[email protected]> Co-authored-by: JRD971000 <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* ci: Add mode files to filter (NVIDIA-NeMo#13517) Signed-off-by: oliver könig <[email protected]> Signed-off-by: Hainan Xu <[email protected]>
* change default merge_algo for buffered inference to None Signed-off-by: Hainan Xu <[email protected]>
* Apply isort and black reformatting Signed-off-by: hainan-xv <[email protected]>

---------

Signed-off-by: Francesco Bertolotti <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
Signed-off-by: oliver könig <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Signed-off-by: Charlie Truong <[email protected]>
Signed-off-by: Pranav Prashant Thombre <[email protected]>
Signed-off-by: Youngeun Kwon <[email protected]>
Signed-off-by: Jan Lasek <[email protected]>
Signed-off-by: janekl <[email protected]>
Signed-off-by: hainan-xv <[email protected]>
Signed-off-by: JRD971000 <[email protected]>
Co-authored-by: Francesco Bertolotti <[email protected]>
Co-authored-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: oliver könig <[email protected]>
Co-authored-by: Hainan Xu <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Charlie Truong <[email protected]>
Co-authored-by: pthombre <[email protected]>
Co-authored-by: Youngeun Kwon <[email protected]>
Co-authored-by: Jan Lasek <[email protected]>
Co-authored-by: janekl <[email protected]>
Co-authored-by: hainan-xv <[email protected]>
Co-authored-by: Ali Taghibakhshi <[email protected]>
Co-authored-by: JRD971000 <[email protected]>
Signed-off-by: jianbinc <[email protected]>
1 parent cb40770 commit ff43bc2

File tree

3 files changed: +135 -4 lines changed

examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py
nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
nemo/collections/asr/parts/utils/streaming_utils.py

examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py

Lines changed: 29 additions & 2 deletions
@@ -23,6 +23,9 @@
 Buffered inference will use large chunk sizes (5-10 seconds) + some additional buffer for context.
 Streaming inference will use small chunk sizes (0.1 to 0.25 seconds) + some additional buffer for context.
 
+Note, currently greedy_batched inference for TDT is not supported. Decoding strategy will be set to greedy for
+TDT automatically.
+
 # Middle Token merge algorithm
 
 python speech_to_text_buffered_infer_rnnt.py \
@@ -73,6 +76,7 @@
 from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
 from nemo.collections.asr.parts.utils.streaming_utils import (
     BatchedFrameASRRNNT,
+    BatchedFrameASRTDT,
     LongestCommonSubsequenceBatchedFrameASRRNNT,
 )
 from nemo.collections.asr.parts.utils.transcribe_utils import (
@@ -135,7 +139,10 @@ class TranscriptionConfig:
     stateful_decoding: bool = False  # Whether to perform stateful decoding
 
     # Merge algorithm for transducers
-    merge_algo: Optional[str] = 'middle'  # choices=['middle', 'lcs'], choice of algorithm to apply during inference.
+    # choices=['middle', 'lcs', 'tdt'], choice of algorithm to apply during inference.
+    # if None, we use 'middle' for rnnt and 'tdt' for tdt.
+    merge_algo: Optional[str] = None
+
     lcs_alignment_dir: Optional[str] = None  # Path to a directory to store LCS algo alignments
 
     # Config for word / character error rate calculation
@@ -150,6 +157,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
     """
     Transcribes the input audio and can be used to infer long audio files by chunking
     them into smaller segments.
+    Currently, greedy_batched inference for TDT is not supported. Decoding strategy
+    will be set to greedy for TDT automatically.
     """
     logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
     torch.set_grad_enabled(False)
@@ -212,9 +221,17 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
     asr_model.freeze()
     asr_model = asr_model.to(asr_model.device)
 
+    model_is_tdt = hasattr(asr_model.loss, '_loss') and type(asr_model.loss._loss).__name__ == 'TDTLossNumba'
+    if cfg.merge_algo is None:
+        cfg.merge_algo = "tdt" if model_is_tdt else "middle"
+        logging.info(f"merge_algo not specified. We use the default algorithm (middle for rnnt and tdt for tdt).")
+
+    if model_is_tdt and cfg.merge_algo != "tdt":
+        raise ValueError("merge_algo must be 'tdt' for TDT models")
+
     # Change Decoding Config
     with open_dict(cfg.decoding):
-        if cfg.stateful_decoding:
+        if cfg.stateful_decoding or cfg.merge_algo == 'tdt':
             cfg.decoding.strategy = "greedy"
         else:
             cfg.decoding.strategy = "greedy_batch"
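
For readers skimming the diff, the defaulting and validation logic added above can be exercised in isolation. The sketch below mirrors the check from the diff (a model whose loss wraps a `_loss` object of a class named `TDTLossNumba`); the Dummy* classes are hypothetical stand-ins so the snippet runs without NeMo installed.

# Minimal sketch of the merge_algo defaulting added in this hunk.
# The Dummy* classes are hypothetical stand-ins; only the 'TDTLossNumba' name check is taken from the diff.
from typing import Optional


class TDTLossNumba:  # stand-in carrying only the class name the check looks for
    pass


class DummyJoint:
    def __init__(self):
        self._loss = TDTLossNumba()


class DummyTDTModel:
    def __init__(self):
        self.loss = DummyJoint()


def resolve_merge_algo(model, merge_algo: Optional[str]) -> str:
    """Default to 'tdt' for TDT models and 'middle' otherwise, rejecting invalid combinations."""
    model_is_tdt = hasattr(model.loss, '_loss') and type(model.loss._loss).__name__ == 'TDTLossNumba'
    if merge_algo is None:
        merge_algo = "tdt" if model_is_tdt else "middle"
    if model_is_tdt and merge_algo != "tdt":
        raise ValueError("merge_algo must be 'tdt' for TDT models")
    return merge_algo


print(resolve_merge_algo(DummyTDTModel(), None))  # -> tdt

With a non-TDT model (whose inner loss class has a different name), the same helper returns 'middle', matching the previous default.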
@@ -267,6 +284,16 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
         # Set the LCS algorithm delay.
         frame_asr.lcs_delay = math.floor(((total_buffer - chunk_len)) / model_stride_in_secs)
 
+    elif cfg.merge_algo == 'tdt':
+        frame_asr = BatchedFrameASRTDT(
+            asr_model=asr_model,
+            frame_len=chunk_len,
+            total_buffer=cfg.total_buffer_in_secs,
+            batch_size=cfg.batch_size,
+            max_steps_per_timestep=cfg.max_steps_per_timestep,
+            stateful_decoding=cfg.stateful_decoding,
+        )
+
     else:
         raise ValueError("Invalid choice of merge algorithm for transducer buffered inference.")
 
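
Taken as a whole, the script now selects one of three frame-batch classes from cfg.merge_algo. The sketch below is a simplified paraphrase of that if/elif chain, assuming a NeMo installation; the real script passes different constructor arguments to each class (see the hunks above), which are simply forwarded here.

# Simplified paraphrase of the merge_algo dispatch in speech_to_text_buffered_infer_rnnt.py.
# Assumes NeMo is installed; constructor arguments differ per class and are forwarded untouched.
from nemo.collections.asr.parts.utils.streaming_utils import (
    BatchedFrameASRRNNT,
    BatchedFrameASRTDT,
    LongestCommonSubsequenceBatchedFrameASRRNNT,
)

MERGE_ALGO_TO_CLASS = {
    'middle': BatchedFrameASRRNNT,                        # middle-token merging (default for RNNT)
    'lcs': LongestCommonSubsequenceBatchedFrameASRRNNT,   # longest-common-subsequence merging
    'tdt': BatchedFrameASRTDT,                            # new class introduced by this commit
}


def make_frame_asr(merge_algo: str, **kwargs):
    """Instantiate the frame-batch ASR helper matching the requested merge algorithm."""
    if merge_algo not in MERGE_ALGO_TO_CLASS:
        raise ValueError("Invalid choice of merge algorithm for transducer buffered inference.")
    return MERGE_ALGO_TO_CLASS[merge_algo](**kwargs)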

nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py

Lines changed: 4 additions & 2 deletions
@@ -2673,10 +2673,12 @@ def _greedy_decode(
 
             if self.preserve_alignments:
                 # convert Ti-th logits into a torch array
-                hypothesis.alignments.append([])  # blank buffer for next timestep
+                for i in range(skip):
+                    hypothesis.alignments.append([])  # blank buffer until next timestep
 
             if self.preserve_frame_confidence:
-                hypothesis.frame_confidence.append([])  # blank buffer for next timestep
+                for i in range(skip):
+                    hypothesis.frame_confidence.append([])  # blank buffer for next timestep
 
             if symbols_added == self.max_symbols:
                 time_idx += 1
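
The change above matters because a TDT decoder predicts a duration and may advance several acoustic frames per step (the `skip` variable), while `hypothesis.alignments` and `hypothesis.frame_confidence` are meant to hold one entry per frame. A toy sketch, not NeMo code, of why appending a single blank buffer per decoding step (the old behaviour) desynchronizes the buffer:

# Toy illustration: with per-step durations > 1, appending one blank buffer per decoding step
# leaves the alignment buffer shorter than the number of acoustic frames; appending `skip`
# buffers (the fix) keeps it one entry per frame.
def buffer_length(durations, append_per_skipped_frame):
    alignments = []
    for skip in durations:  # `skip` stands in for the duration predicted by the TDT head
        for _ in range(skip if append_per_skipped_frame else 1):
            alignments.append([])  # blank buffer for a (skipped) timestep
    return len(alignments)


durations = [1, 3, 2, 4]  # hypothetical durations; the audio spans sum(durations) = 10 frames
print(buffer_length(durations, append_per_skipped_frame=False))  # 4  -> old behaviour, misaligned
print(buffer_length(durations, append_per_skipped_frame=True))   # 10 -> fixed, matches frame count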

nemo/collections/asr/parts/utils/streaming_utils.py

Lines changed: 102 additions & 0 deletions
@@ -1264,6 +1264,108 @@ def greedy_merge(self, preds):
         return hypothesis
 
 
+class BatchedFrameASRTDT(BatchedFrameASRRNNT):
+    """
+    Batched implementation of FrameBatchASR for TDT models, where the batch dimension is independent audio samples.
+    It's mostly similar to BatchedFrameASRRNNT with special handling of boundary cases due to the frame-skipping
+    resulting from TDT models.
+    """
+
+    def __init__(
+        self,
+        asr_model,
+        frame_len=1.6,
+        total_buffer=4.0,
+        batch_size=32,
+        max_steps_per_timestep: int = 5,
+        stateful_decoding: bool = False,
+        tdt_search_boundary: int = 4,
+    ):
+        '''
+        Args:
+            asr_model: An RNNT model.
+            frame_len: frame's duration, seconds.
+            total_buffer: duration of total audio chunk size, in seconds.
+            batch_size: Number of independent audio samples to process at each step.
+            max_steps_per_timestep: Maximum number of tokens (u) to process per acoustic timestep (t).
+            stateful_decoding: Boolean whether to enable stateful decoding for preservation of state across buffers.
+            tdt_search_boundary: The max number of frames that we search between chunks to match the token at the boundary.
+        '''
+        super().__init__(asr_model, frame_len=frame_len, total_buffer=total_buffer, batch_size=batch_size)
+        self.tdt_search_boundary = tdt_search_boundary
+
+    def transcribe(
+        self,
+        tokens_per_chunk: int,
+        delay: int,
+    ):
+        """
+        Performs "middle token" alignment prediction using the buffered audio chunk.
+        """
+        self.infer_logits()
+
+        self.unmerged = [[] for _ in range(self.batch_size)]
+        for idx, alignments in enumerate(self.all_alignments):
+
+            signal_end_idx = self.frame_bufferer.signal_end_index[idx]
+            if signal_end_idx is None:
+                raise ValueError("Signal did not end")
+
+            for a_idx, alignment in enumerate(alignments):
+                if delay == len(alignment):  # chunk size = buffer size
+                    offset = 0
+                else:  # all other cases
+                    offset = 1
+
+                longer_alignment = alignment[
+                    len(alignment)
+                    - offset
+                    - delay
+                    - self.tdt_search_boundary : len(alignment)
+                    - offset
+                    - delay
+                    + tokens_per_chunk
+                ]
+
+                alignment = alignment[
+                    len(alignment) - offset - delay : len(alignment) - offset - delay + tokens_per_chunk
+                ]
+
+                longer_ids, longer_toks = self._alignment_decoder(
+                    longer_alignment, self.asr_model.tokenizer, self.blank_id
+                )
+                ids, _ = self._alignment_decoder(alignment, self.asr_model.tokenizer, self.blank_id)
+
+                if len(longer_ids) > 0 and a_idx < signal_end_idx:
+                    if a_idx == 0 or len(self.unmerged[idx]) == 0:
+                        self.unmerged[idx] = inplace_buffer_merge(
+                            self.unmerged[idx],
+                            ids,
+                            delay,
+                            model=self.asr_model,
+                        )
+                    elif len(self.unmerged[idx]) > 0 and len(longer_toks) > 1:
+                        id_to_match = self.unmerged[idx][-1]
+                        start = min(len(longer_ids) - len(ids), len(longer_ids) - 1)
+                        end = -1
+                        for i in range(start, end, -1):
+                            if longer_ids[i] == id_to_match:
+                                ids = longer_ids[i + 1 :]
+                                break
+
+                        self.unmerged[idx] = inplace_buffer_merge(
+                            self.unmerged[idx],
+                            ids,
+                            delay,
+                            model=self.asr_model,
+                        )
+
+        output = []
+        for idx in range(self.batch_size):
+            output.append(self.greedy_merge(self.unmerged[idx]))
+        return output
+
+
 class LongestCommonSubsequenceBatchedFrameASRRNNT(BatchedFrameASRRNNT):
     """
     Implements a token alignment algorithm for text alignment instead of middle token alignment.
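
The heart of BatchedFrameASRTDT.transcribe is the boundary search: because TDT frame-skipping can shift where a chunk's tokens begin, the class decodes a slightly widened window (tdt_search_boundary extra frames), looks backwards through it for the last token already merged from the previous chunk, and keeps only what follows. A standalone sketch of that trimming step with hypothetical token ids (the real method works on decoded alignments and also checks that the widened window decoded more than one token):

# Standalone sketch (hypothetical ids) of the boundary-matching trim in BatchedFrameASRTDT.transcribe.
def trim_to_boundary(unmerged, longer_ids, ids):
    """Return the ids to merge for the current chunk, dropping tokens already emitted at the boundary."""
    if not unmerged or len(longer_ids) <= 1:
        return ids
    id_to_match = unmerged[-1]  # last token merged from previous chunks
    start = min(len(longer_ids) - len(ids), len(longer_ids) - 1)
    for i in range(start, -1, -1):  # walk backwards over the widened search window
        if longer_ids[i] == id_to_match:
            return longer_ids[i + 1:]  # resume right after the matched token
    return ids  # no match found: fall back to the nominal window


unmerged = [7, 12, 5]          # tokens already merged from previous chunks (hypothetical)
longer_ids = [12, 5, 9, 3, 8]  # decode of the widened window for the current chunk
ids = [9, 3, 8]                # decode of the nominal window
print(trim_to_boundary(unmerged, longer_ids, ids))  # -> [9, 3, 8]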

0 commit comments