diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py
index c31fa2b9d812..64f7b0ee8260 100644
--- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py
+++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py
@@ -23,6 +23,9 @@
 Buffered inference will use large chunk sizes (5-10 seconds) + some additional buffer for context.
 Streaming inference will use small chunk sizes (0.1 to 0.25 seconds) + some additional buffer for context.
 
+Note: currently, greedy_batch inference for TDT models is not supported. The decoding strategy will be set to
+greedy automatically for TDT.
+
 # Middle Token merge algorithm
 
 python speech_to_text_buffered_infer_rnnt.py \
@@ -73,6 +76,7 @@
 from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
 from nemo.collections.asr.parts.utils.streaming_utils import (
     BatchedFrameASRRNNT,
+    BatchedFrameASRTDT,
     LongestCommonSubsequenceBatchedFrameASRRNNT,
 )
 from nemo.collections.asr.parts.utils.transcribe_utils import (
@@ -135,7 +139,10 @@ class TranscriptionConfig:
     stateful_decoding: bool = False  # Whether to perform stateful decoding
 
     # Merge algorithm for transducers
-    merge_algo: Optional[str] = 'middle'  # choices=['middle', 'lcs'], choice of algorithm to apply during inference.
+    # choices=['middle', 'lcs', 'tdt']; the merge algorithm to apply during inference.
+    # If None, defaults to 'middle' for RNNT models and 'tdt' for TDT models.
+    merge_algo: Optional[str] = None
+
     lcs_alignment_dir: Optional[str] = None  # Path to a directory to store LCS algo alignments
 
     # Config for word / character error rate calculation
@@ -150,6 +157,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
     """
     Transcribes the input audio and can be used to infer long audio files by chunking them into smaller segments.
+    Currently, greedy_batch inference for TDT is not supported. The decoding strategy
+    will be set to greedy automatically for TDT.
     """
     logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
     torch.set_grad_enabled(False)
 
@@ -212,9 +221,17 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
     asr_model.freeze()
     asr_model = asr_model.to(asr_model.device)
 
+    model_is_tdt = hasattr(asr_model.loss, '_loss') and type(asr_model.loss._loss).__name__ == 'TDTLossNumba'
+    if cfg.merge_algo is None:
+        cfg.merge_algo = "tdt" if model_is_tdt else "middle"
+        logging.info("merge_algo not specified; using the default ('middle' for RNNT models, 'tdt' for TDT models).")
+
+    if model_is_tdt and cfg.merge_algo != "tdt":
+        raise ValueError("merge_algo must be 'tdt' for TDT models")
+
     # Change Decoding Config
     with open_dict(cfg.decoding):
-        if cfg.stateful_decoding:
+        if cfg.stateful_decoding or cfg.merge_algo == 'tdt':
             cfg.decoding.strategy = "greedy"
         else:
             cfg.decoding.strategy = "greedy_batch"
@@ -267,6 +284,16 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
             # Set the LCS algorithm delay.
            frame_asr.lcs_delay = math.floor(((total_buffer - chunk_len)) / model_stride_in_secs)
+        elif cfg.merge_algo == 'tdt':
+            frame_asr = BatchedFrameASRTDT(
+                asr_model=asr_model,
+                frame_len=chunk_len,
+                total_buffer=cfg.total_buffer_in_secs,
+                batch_size=cfg.batch_size,
+                max_steps_per_timestep=cfg.max_steps_per_timestep,
+                stateful_decoding=cfg.stateful_decoding,
+            )
+
         else:
             raise ValueError("Invalid choice of merge algorithm for transducer buffered inference.")
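For reference, a buffered TDT run could then be launched the same way as the existing docstring examples; the paths below are placeholders, the buffer/chunk values are illustrative only, and merge_algo could also be omitted since it now defaults to 'tdt' for TDT models:

python speech_to_text_buffered_infer_rnnt.py \
    model_path=<path to a TDT .nemo file> \
    dataset_manifest=<path to dataset manifest json> \
    output_filename=<path to output json> \
    total_buffer_in_secs=4.0 \
    chunk_len_in_secs=1.6 \
    batch_size=32 \
    merge_algo="tdt"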
diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
index b5495802d2fd..c281fae0be55 100644
--- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -2673,10 +2673,12 @@ def _greedy_decode(
 
             if self.preserve_alignments:
                 # convert Ti-th logits into a torch array
-                hypothesis.alignments.append([])  # blank buffer for next timestep
+                for _ in range(skip):
+                    hypothesis.alignments.append([])  # one blank buffer per skipped timestep
 
             if self.preserve_frame_confidence:
-                hypothesis.frame_confidence.append([])  # blank buffer for next timestep
+                for _ in range(skip):
+                    hypothesis.frame_confidence.append([])  # one blank buffer per skipped timestep
 
             if symbols_added == self.max_symbols:
                 time_idx += 1
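The change above matters because a TDT duration prediction can jump several acoustic frames at once: appending a single blank buffer per decoding step would leave the per-frame alignment (and frame-confidence) lists shorter than the number of frames actually consumed. A minimal, self-contained sketch with invented tokens and durations:

# One empty alignment bucket must be appended per skipped frame so the list
# stays index-aligned with acoustic frames (all values are illustrative only).
alignments = [[]]  # one bucket per acoustic frame; bucket for frame 0
time_idx = 0
for token, skip in [(7, 1), (3, 2), (9, 3)]:  # (emitted token, predicted duration)
    alignments[time_idx].append(token)
    for _ in range(skip):  # mirrors the loop added in this patch
        alignments.append([])
    time_idx += skip
assert len(alignments) == 7  # frames 0..6 all have a bucket
assert time_idx == 6  # decoding resumes at the correct frame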
+ """ + self.infer_logits() + + self.unmerged = [[] for _ in range(self.batch_size)] + for idx, alignments in enumerate(self.all_alignments): + + signal_end_idx = self.frame_bufferer.signal_end_index[idx] + if signal_end_idx is None: + raise ValueError("Signal did not end") + + for a_idx, alignment in enumerate(alignments): + if delay == len(alignment): # chunk size = buffer size + offset = 0 + else: # all other cases + offset = 1 + + longer_alignment = alignment[ + len(alignment) + - offset + - delay + - self.tdt_search_boundary : len(alignment) + - offset + - delay + + tokens_per_chunk + ] + + alignment = alignment[ + len(alignment) - offset - delay : len(alignment) - offset - delay + tokens_per_chunk + ] + + longer_ids, longer_toks = self._alignment_decoder( + longer_alignment, self.asr_model.tokenizer, self.blank_id + ) + ids, _ = self._alignment_decoder(alignment, self.asr_model.tokenizer, self.blank_id) + + if len(longer_ids) > 0 and a_idx < signal_end_idx: + if a_idx == 0 or len(self.unmerged[idx]) == 0: + self.unmerged[idx] = inplace_buffer_merge( + self.unmerged[idx], + ids, + delay, + model=self.asr_model, + ) + elif len(self.unmerged[idx]) > 0 and len(longer_toks) > 1: + id_to_match = self.unmerged[idx][-1] + start = min(len(longer_ids) - len(ids), len(longer_ids) - 1) + end = -1 + for i in range(start, end, -1): + if longer_ids[i] == id_to_match: + ids = longer_ids[i + 1 :] + break + + self.unmerged[idx] = inplace_buffer_merge( + self.unmerged[idx], + ids, + delay, + model=self.asr_model, + ) + + output = [] + for idx in range(self.batch_size): + output.append(self.greedy_merge(self.unmerged[idx])) + return output + + class LongestCommonSubsequenceBatchedFrameASRRNNT(BatchedFrameASRRNNT): """ Implements a token alignment algorithm for text alignment instead of middle token alignment.