diff --git a/nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py b/nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py new file mode 100644 index 000000000000..dc33b3c76fb9 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py @@ -0,0 +1,696 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional, Union + +import numpy as np +import torch + +from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel +from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import BatchedBeamHyps +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs +from nemo.core.utils.cuda_python_utils import ( + check_cuda_python_cuda_graphs_conditional_nodes_supported, + cu_call, + run_nvrtc, + with_conditional_node, +) +from nemo.utils import logging +from nemo.utils.enum import PrettyStrEnum + +try: + from cuda import cudart + + HAVE_CUDA_PYTHON = True +except ImportError: + HAVE_CUDA_PYTHON = False + + +NON_EXISTENT_LABEL_VALUE = -1 +INACTIVE_SCORE = float("-inf") + + +class BacthedBeamCTCState: + """ + State for Batched Beam Search for CTC models. Used only with CUDA graphs. + In initialization phase it is possible to assign values (tensors) to the state. + For algorithm code the storage should be reused (prefer copy data instead of assigning tensors). + """ + + max_time: int # maximum length of internal storage for time dimension + batch_size: int # (maximum) length of internal storage for batch dimension + device: torch.device # device to store preallocated tensors + beam_size: int # (maximum) length of internal storage for beam dimension + blank_index: int # the index of the blank token + + decoder_outputs: torch.Tensor # logprobs from decoder + decoder_output_lengths: torch.Tensor # lengths of the decoder outputs (i.e. max time for each utterance) + last_timesteps: torch.Tensor # last time step for each utterance (used to check if the decoding is finished) + + vocab: torch.Tensor # vocabulary of the model. Constant + vocab_blank_mask: torch.Tensor # mask for blank token in the vocabulary. 
Constant + + curr_frame_idx: torch.Tensor # current frame index for each utterance (used to check if the decoding is finished) + active_mask: torch.Tensor # mask for active hypotheses (the decoding is finished for the utterance if it is False) + active_mask_any: torch.Tensor # 0-dim bool tensor, condition for outer loop ('any element is still active') + + batched_hyps: BatchedBeamHyps # batched hypotheses - decoding result + + # NGramGPULM-related fields + ngram_lm_batch: Optional[NGramGPULanguageModel] = None # N-gram LM for hypotheses + batch_lm_states: Optional[torch.Tensor] = None # LM states for hypotheses + batch_lm_states_candidates: Optional[torch.Tensor] = None # LM states for hypotheses candidates + + def __init__( + self, + batch_size: int, + beam_size: int, + max_time: int, + vocab_size: int, + device: torch.device, + float_dtype: torch.dtype, + blank_index: int, + ): + """ + Args: + batch_size: batch size for encoder output storage + beam_size: beam size for decoder output storage + max_time: maximum time for encoder output storage + vocab_size: vocabulary size of the model including blank + device: device to store tensors + float_dtype: default float dtype for tensors (should match projected encoder output) + blank_index: index of the blank symbol + """ + + self.device = device + self.float_dtype = float_dtype + self.batch_size = batch_size + self.beam_size = beam_size + self.max_time = max_time + self.blank_index = blank_index + self.vocab_size = vocab_size + + self.NON_EXISTENT_LABEL = torch.tensor(NON_EXISTENT_LABEL_VALUE, device=self.device, dtype=torch.long) + self.BLANK_TENSOR = torch.tensor(self.blank_index, device=self.device, dtype=torch.long) + self.INACTIVE_SCORE = torch.tensor(INACTIVE_SCORE, device=self.device, dtype=float_dtype) + + self.decoder_outputs = torch.zeros( + (self.batch_size, self.max_time, self.vocab_size), + dtype=float_dtype, + device=self.device, + ) + self.decoder_output_lengths = torch.zeros( + (self.batch_size, self.beam_size), dtype=torch.long, device=self.device + ) + self.last_timesteps = torch.zeros((self.batch_size, self.beam_size), dtype=torch.long, device=self.device) + + self.vocab = torch.arange(self.vocab_size, device=self.device, dtype=torch.long) + self.vocab_blank_mask = torch.eq(self.vocab, self.blank_index) + + self.curr_frame_idx = torch.zeros([self.beam_size], device=self.device, dtype=torch.long) + self.active_mask = torch.zeros((batch_size, self.beam_size), device=self.device, dtype=torch.bool) + self.active_mask_any = torch.tensor(True, device=self.device, dtype=torch.bool) + + self.batched_hyps = BatchedBeamHyps( + batch_size=batch_size, + beam_size=self.beam_size, + blank_index=self.blank_index, + init_length=max_time + 1, + device=device, + float_dtype=float_dtype, + model_type='ctc', + ) + + def need_reinit(self, encoder_output_projected: torch.Tensor) -> bool: + """Check if need to reinit state: larger batch_size/max_time, or new device""" + return ( + self.batch_size < encoder_output_projected.shape[0] + or self.max_time < encoder_output_projected.shape[1] + or self.device.index != encoder_output_projected.device.index + ) + + +@dataclass +class SeparateGraphsBatchedBeamCTC: + """Class to store Cuda graphs for decoding when separate graphs are used""" + + _before_process_batch: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + _process_batch: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + _after_process_batch: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + 
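Editor's note: the state class and CUDA-graph containers above follow a preallocate-and-reuse convention — `need_reinit` triggers reallocation only when the incoming batch is larger (in batch or time dimension) or lives on a different device, and otherwise new inputs are copied into the existing buffers so tensor addresses stay stable for graph replay. The following minimal sketch (hypothetical class and field names, not part of this patch) illustrates that pattern, assuming the inputs are CTC log-probabilities:

```python
import torch


class CTCBeamDecoderStateReuse:
    """Hypothetical wrapper illustrating the preallocate-and-reuse pattern."""

    def __init__(self):
        self.state = None  # lazily created, reused across calls

    def _need_reinit(self, log_probs: torch.Tensor) -> bool:
        # Reallocate only if the stored buffers are too small or on another device.
        return (
            self.state is None
            or self.state["log_probs"].shape[0] < log_probs.shape[0]
            or self.state["log_probs"].shape[1] < log_probs.shape[1]
            or self.state["log_probs"].device != log_probs.device
        )

    def decode(self, log_probs: torch.Tensor, lengths: torch.Tensor) -> None:
        batch, time, _ = log_probs.shape
        if self._need_reinit(log_probs):
            # Allocate buffers sized to the current input; later calls that fit
            # reuse the same memory, keeping tensor addresses stable
            # (a prerequisite for replaying a captured CUDA graph).
            self.state = {
                "log_probs": torch.zeros_like(log_probs),
                "lengths": torch.zeros(batch, dtype=torch.long, device=log_probs.device),
            }
        # Copy inputs into the existing buffers instead of rebinding new tensors.
        self.state["lengths"].fill_(0)
        self.state["log_probs"][:batch, :time].copy_(log_probs)
        self.state["lengths"][:batch].copy_(lengths)
        # ... run (or replay) the frame-by-frame decoding loop on self.state ...
```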
+ +class BatchedBeamCTCComputer(WithOptionalCudaGraphs, ConfidenceMethodMixin): + """ + Batched beam search implementation for CTC models. + """ + + INITIAL_MAX_TIME = 375 # initial max time, used to init state for Cuda graphs + CUDA_PROGRAM_NAME = b"while_beam_batch_conditional_ctc.cu" + + class CudaGraphsMode(PrettyStrEnum): + FULL_GRAPH = "full_graph" # Cuda graphs with conditional nodes, fastest implementation + NO_WHILE_LOOPS = "no_while_loops" # Decoding with PyTorch while loops + partial Cuda graphs + NO_GRAPHS = "no_graphs" # decoding without graphs, stateful implementation, only for testing purposes + + separate_graphs: Optional[SeparateGraphsBatchedBeamCTC] + full_graph: Optional[torch.cuda.CUDAGraph] + cuda_graphs_mode: Optional[CudaGraphsMode] + state: Optional[BacthedBeamCTCState] + ngram_lm_batch: Optional[NGramGPULanguageModel] + + def __init__( + self, + blank_index: int, + beam_size: int, + return_best_hypothesis: bool = True, + preserve_alignments=False, + compute_timestamps: bool = False, + ngram_lm_alpha: float = 1.0, + beam_beta: float = 0.0, + beam_threshold: float = 20.0, + ngram_lm_model: str = None, + allow_cuda_graphs: bool = True, + ): + """ + Init method. + Args: + blank_index: index of blank symbol. + beam_size: beam size. + return_best_hypothesis: whether to return the best hypothesis or N-best hypotheses. + preserve_alignments: if alignments are needed. Defaults to False. + compute_timestamps: if timestamps are needed. Defaults to False. + ngram_lm_model: path to the NGPU-LM n-gram LM model: .arpa or .nemo formats. + ngram_lm_alpha: weight for the n-gram LM scores. + beam_beta: word insertion weight. + beam_threshold: threshold for pruning candidates. + allow_cuda_graphs: whether to allow CUDA graphs. Defaults to True. + """ + + super().__init__() + self._blank_index = blank_index + + self.beam_size = beam_size + self.preserve_alignments = preserve_alignments + self.compute_timestamps = compute_timestamps + self.allow_cuda_graphs = allow_cuda_graphs + self.return_best_hypothesis = return_best_hypothesis + + self.ngram_lm_alpha = ngram_lm_alpha + self.beam_beta = beam_beta + self.beam_threshold = beam_threshold + + assert not self.preserve_alignments, "Preserve aligments is not supported" + + self.state = None + self.full_graph = None + self.separate_graphs = None + + self.cuda_graphs_mode = None + self.maybe_enable_cuda_graphs() + + self.ngram_lm_batch = None + if ngram_lm_model is not None: + assert self._blank_index != 0, "Blank should not be the first token in the vocabulary" + self.ngram_lm_batch = NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index) + + def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]): + """ + Method to set graphs mode. Use only for testing purposes. + For debugging the algorithm use "no_graphs" mode, since it is impossible to debug CUDA graphs directly. 
+ """ + self.cuda_graphs_mode = self.CudaGraphsMode(mode) if mode is not None else None + self.state = None + + def maybe_enable_cuda_graphs(self): + """Enable CUDA graphs if conditions met""" + if self.cuda_graphs_mode is not None: + # CUDA graphs are already enabled + return + + if not self.allow_cuda_graphs: + self.cuda_graphs_mode = None + else: + # cuda graphs are allowed + # check while loops + try: + check_cuda_python_cuda_graphs_conditional_nodes_supported() + self.cuda_graphs_mode = self.CudaGraphsMode.FULL_GRAPH + except (ImportError, ModuleNotFoundError, EnvironmentError) as e: + logging.warning( + "No conditional node support for Cuda.\n" + "Cuda graphs with while loops are disabled, decoding speed will be slower\n" + f"Reason: {e}" + ) + self.cuda_graphs_mode = self.CudaGraphsMode.NO_GRAPHS + self.reset_cuda_graphs_state() + + def disable_cuda_graphs(self): + """Disable CUDA graphs, can be used to disable graphs temporary, e.g., in training process""" + if self.cuda_graphs_mode is None: + # nothing to disable + return + self.cuda_graphs_mode = None + self.reset_cuda_graphs_state() + + def reset_cuda_graphs_state(self): + """Reset state to release memory (for CUDA graphs implementations)""" + self.state = None + self.full_graph = None + self.separate_graphs = None + + @torch.no_grad() + def batched_beam_search_torch( + self, decoder_outputs: torch.Tensor, decoder_output_lengths: torch.Tensor + ) -> BatchedBeamHyps: + """ + Pure PyTorch implementation of the batched beam search algorithm. + + Args: + decoder_outputs (torch.Tensor): Tensor of shape [B, T, V+1], where B is the batch size, + T is the maximum sequence length, and V is the vocabulary size. The tensor contains log-probabilities. + decoder_output_lengths (torch.Tensor): Tensor of shape [B], contains lengths of each sequence in the batch. + Returns: + A list of NBestHypotheses objects, one for each sequence in the batch. 
+ """ + + curr_batch_size, curr_max_time, vocab_size = decoder_outputs.shape + + vocab = torch.arange(vocab_size, device=decoder_outputs.device, dtype=torch.long) + vocab_blank_mask = vocab == self._blank_index + + batched_beam_hyps = BatchedBeamHyps( + batch_size=curr_batch_size, + beam_size=self.beam_size, + blank_index=self._blank_index, + init_length=curr_max_time + 1, + device=decoder_outputs.device, + float_dtype=decoder_outputs.dtype, + model_type='ctc', + ) + + if self.ngram_lm_batch is not None: + self.ngram_lm_batch.to(decoder_outputs.device) + batch_lm_states = self.ngram_lm_batch.get_init_states( + batch_size=curr_batch_size * self.beam_size, bos=True + ) + + for frame_idx in range(curr_max_time): + active_mask = frame_idx < decoder_output_lengths.unsqueeze(1) + repeated_mask = batched_beam_hyps.last_label[:, :, None] == vocab[None, None, :] + repeated_or_blank_mask = repeated_mask | vocab_blank_mask[None, None, :] + + # step 2.1: getting the log probs and updating with LM scores + log_probs = decoder_outputs[:, frame_idx, :].unsqueeze(1).repeat(1, self.beam_size, 1) + log_probs += batched_beam_hyps.scores.unsqueeze(-1) + + # step 2.2: updating non-blank and non-repeating token scores with `beam_beta` + log_probs = torch.where(repeated_or_blank_mask, log_probs, log_probs + self.beam_beta) + + if self.ngram_lm_batch is not None: + lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states.view(-1)) + lm_scores = torch.where( + repeated_mask[..., :-1], 0, lm_scores.view(curr_batch_size, self.beam_size, -1) + ) + log_probs[..., :-1] += self.ngram_lm_alpha * lm_scores.view(curr_batch_size, self.beam_size, -1) + + # step 2.3: getting `beam_size` best candidates + next_scores, next_candidates_indices = torch.topk( + log_probs.view(curr_batch_size, -1), k=self.beam_size, largest=True, sorted=True + ) + next_indices = next_candidates_indices // vocab_size + next_labels = next_candidates_indices % vocab_size + + # step 2.3: pruning candidates with threshold `beam_threshold` + batch_next_scores = next_scores.view(curr_batch_size, -1) + max_next_score = batch_next_scores.max(dim=-1, keepdim=True).values + batch_next_scores.masked_fill_(batch_next_scores <= max_next_score - self.beam_threshold, INACTIVE_SCORE) + next_scores.view(curr_batch_size, self.beam_size, -1) + + # step 2.4: preserving updated lm states + if self.ngram_lm_batch is not None: + last_labels = torch.gather(batched_beam_hyps.last_label, dim=-1, index=next_indices) + blank_mask = next_labels == self._blank_index + repeating_mask = next_labels == last_labels + preserve_state_mask = repeating_mask | blank_mask | ~active_mask + + # step 2.4.1: masking blanks and inactive labels to pass to LM, as LM does not support blanks + next_labels_masked = torch.where(blank_mask, 0, next_labels) + + # step 2.4.2: gathering LM states of extended hypotheses + # batch_lm_states: [(BxBeam)] + # batch_lm_states_candidates: [(BxBeam) x V (without blank)] + next_indices_extended = next_indices[:, :, None].expand( + curr_batch_size, self.beam_size, batch_lm_states_candidates.shape[-1] + ) + batch_lm_states_candidates = batch_lm_states_candidates.view(curr_batch_size, self.beam_size, -1) + batch_lm_states_candidates = torch.gather( + batch_lm_states_candidates, dim=1, index=next_indices_extended + ) + batch_lm_states_prev = torch.gather( + batch_lm_states.view(curr_batch_size, self.beam_size), dim=1, index=next_indices + ) + batch_lm_states = torch.gather( + batch_lm_states_candidates, dim=-1, 
index=next_labels_masked.unsqueeze(-1) + ).squeeze(-1) + + batch_lm_states = torch.where(preserve_state_mask, batch_lm_states_prev, batch_lm_states).view(-1) + + # step 2.5: masking inactive hypotheses, updating + recombining batched beam hypotheses + next_labels = torch.where(active_mask, next_labels, NON_EXISTENT_LABEL_VALUE) + batched_beam_hyps.add_results_(next_indices, next_labels, next_scores) + batched_beam_hyps.recombine_hyps_() + + # step 3: updating LM scores with eos scores + if self.ngram_lm_batch is not None: + eos_score = self.ngram_lm_batch.get_final(batch_lm_states).view(batched_beam_hyps.scores.shape) + batched_beam_hyps.scores += eos_score * self.ngram_lm_alpha + + return batched_beam_hyps + + def batched_beam_search_cuda_graphs( + self, + decoder_outputs: torch.Tensor, + decoder_output_lengths: torch.Tensor, + ) -> BatchedBeamHyps: + """ + Cuda-Graphs implementation of the batched beam search algorithm. + + Args: + decoder_outputs (torch.Tensor): Tensor of shape [B, T, V+1], where B is the batch size, + T is the maximum sequence length, and V is the vocabulary size. The tensor contains log-probabilities. + decoder_output_lengths (torch.Tensor): Tensor of shape [B], contains lengths of each sequence in the batch. + Returns: + A BatchedBeamHyps object with the decoded hypotheses for each sequence in the batch. + """ + + assert self.cuda_graphs_mode is not None + + curr_batch_size, curr_max_time, _ = decoder_outputs.shape + + if torch.is_autocast_enabled(): + decoder_outputs = decoder_outputs.to(torch.get_autocast_gpu_dtype()) + + # init or reinit graph + if self.state is None or self.state.need_reinit(decoder_outputs): + self._graph_reinitialize(decoder_outputs, decoder_output_lengths) + + # set length to zero for elements outside the current batch + self.state.decoder_output_lengths.fill_(0) + # copy decoder outputs (log probs) and lengths + self.state.decoder_outputs[:curr_batch_size, :curr_max_time, ...].copy_(decoder_outputs) + self.state.decoder_output_lengths[:curr_batch_size].copy_(decoder_output_lengths.unsqueeze(-1)) + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: + self.full_graph.replay() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: + self.separate_graphs._before_process_batch.replay() + while self.state.active_mask_any.item(): + self.separate_graphs._process_batch.replay() + self.separate_graphs._after_process_batch.replay() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: + # this mode is only for testing purposes + # manual loop instead of using graphs + self._before_process_batch() + while self.state.active_mask_any.item(): + self._process_batch() + self._after_process_batch() + else: + raise NotImplementedError(f"Unknown graph mode: {self.cuda_graphs_mode}") + + return self.state.batched_hyps + + @classmethod + def _create_process_batch_kernel(cls): + """ + Creates a kernel that evaluates whether to enter the outer loop body (not all hypotheses are decoded). + Condition: while(active_mask_any). 
+ """ + kernel_string = r"""\ + typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle; + + extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value); + + extern "C" __global__ + void loop_conditional(cudaGraphConditionalHandle handle, const bool *active_mask_any) + { + cudaGraphSetConditional(handle, *active_mask_any); + } + """ + return run_nvrtc(kernel_string, b"loop_conditional", cls.CUDA_PROGRAM_NAME) + + def _graph_reinitialize( + self, + decoder_outputs: torch.Tensor, + decoder_output_lengths: torch.Tensor, + ): + """ + Reinitializes the graph state for the Beam Search computation. + This method sets up the internal state required for the decoding process, including initializing + decoder outputs, decoder states, and optional n-gram language model states. It also handles CUDA + graph compilation based on the specified mode. + Args: + encoder_output_projected (torch.Tensor): The projected encoder output tensor of shape + (batch_size, max_time, encoder_dim). + encoder_output_length (torch.Tensor): The lengths of the encoder outputs for each batch. + Raises: + NotImplementedError: If an unsupported CUDA graph mode is specified. + """ + + batch_size, max_time, vocab_size = decoder_outputs.shape + + self.state = BacthedBeamCTCState( + batch_size=batch_size, + beam_size=self.beam_size, + max_time=max(max_time, self.INITIAL_MAX_TIME), + vocab_size=vocab_size, + device=decoder_outputs.device, + float_dtype=decoder_outputs.dtype, + blank_index=self._blank_index, + ) + + if self.ngram_lm_batch is not None: + device = decoder_outputs.device + + self.ngram_lm_batch.to(device) + + batch_lm_states = self.ngram_lm_batch.get_init_states( + batch_size=self.state.batch_size * self.beam_size, bos=True + ) + self.state.batch_lm_states = batch_lm_states.view(self.state.batch_size, self.beam_size) + + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: + self._full_graph_compile() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: + self._partial_graphs_compile() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: + # no graphs needed + pass + else: + raise NotImplementedError + + def _partial_graphs_compile(self): + """Compile decoding by parts""" + # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. + stream_for_graph = torch.cuda.Stream(self.state.device) + stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device)) + self.separate_graphs = SeparateGraphsBatchedBeamCTC() + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph( + self.separate_graphs._before_process_batch, stream=stream_for_graph, capture_error_mode="thread_local" + ), + ): + self._before_process_batch() + + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph( + self.separate_graphs._process_batch, stream=stream_for_graph, capture_error_mode="thread_local" + ), + ): + self._process_batch() + + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph( + self.separate_graphs._after_process_batch, stream=stream_for_graph, capture_error_mode="thread_local" + ), + ): + self._after_process_batch() + + def _full_graph_compile(self): + """Compile full graph for decoding""" + # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. 
+ stream_for_graph = torch.cuda.Stream(self.state.device) + self.full_graph = torch.cuda.CUDAGraph() + + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"), + ): + self._before_process_batch() + capture_status, _, graph, _, _ = cu_call( + cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream) + ) + + assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive + + # capture: while self.active_mask_any: + (loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) + loop_kernel = self._create_process_batch_kernel() + active_mask_any_ptr = np.array([self.state.active_mask_any.data_ptr()], dtype=np.uint64) + loop_args = np.array( + [loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], + dtype=np.uint64, + ) + # loop while there are active utterances + with with_conditional_node(loop_kernel, loop_args, loop_conditional_handle, device=self.state.device): + self._process_batch() + + self._after_process_batch() + + def _before_process_batch(self): + """ + Clears state and setups LM. + """ + # step 1.1: reset state + self.state.batched_hyps.clear_() + self.state.curr_frame_idx.fill_(0) + + # maximum time step for each utterance + torch.sub(self.state.decoder_output_lengths, 1, out=self.state.last_timesteps) + + # masks for utterances in batch + # same as: active_mask = self.encoder_output_length > 0 + torch.greater(self.state.decoder_output_lengths, 0, out=self.state.active_mask) + + # same as: self.active_mask_any = active_mask.any() + torch.any(self.state.active_mask, out=self.state.active_mask_any) + + # step 1.2: setup LM + if self.ngram_lm_batch is not None: + device = self.state.device + self.ngram_lm_batch.to(device) + + batch_lm_states = self.ngram_lm_batch.get_init_states( + batch_size=self.state.batch_size * self.beam_size, bos=True + ) + self.state.batch_lm_states.copy_(batch_lm_states.view(self.state.batch_size, self.beam_size)) + self.state.batch_lm_states_candidates = torch.empty( + (self.state.batch_size, self.state.beam_size, self.ngram_lm_batch.vocab_size), + device=device, + dtype=torch.long, + ) + + def _process_batch(self): + """ + Performs a decoding step. 
+ """ + repeated_mask = self.state.batched_hyps.last_label[:, :, None] == self.state.vocab[None, None, :] + repeated_or_blank_mask = repeated_mask | self.state.vocab_blank_mask[None, None, :] + + # step 2.1: getting the log probs and updating with LM scores + log_probs = self.state.decoder_outputs.index_select(dim=1, index=self.state.curr_frame_idx) + log_probs += self.state.batched_hyps.scores[:, :, None] + + # step 2.2: updating non-blank and non-repeating token scores with `beam_beta` + log_probs = torch.where(repeated_or_blank_mask, log_probs, log_probs + self.beam_beta) + + if self.ngram_lm_batch is not None: + lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( + states=self.state.batch_lm_states.view(-1) + ) + lm_scores = torch.where(repeated_mask[..., :-1], 0, lm_scores.view(log_probs.shape[0], self.beam_size, -1)) + + self.state.batch_lm_states_candidates.copy_( + batch_lm_states_candidates.view(self.state.batch_lm_states_candidates.shape) + ) + log_probs[..., :-1] += self.ngram_lm_alpha * lm_scores.view( + self.state.batch_size, self.state.beam_size, -1 + ) + + # step 2.3: getting `beam_size` best candidates + next_scores, next_candidates_indices = torch.topk( + log_probs.view(self.state.batch_size, -1), k=self.beam_size, largest=True, sorted=True + ) + next_indices = next_candidates_indices // self.state.vocab_size + next_labels = next_candidates_indices % self.state.vocab_size + + # step 2.3: pruning candidates with threshold `beam_threshold` + batch_next_scores = next_scores.view(self.state.batch_size, -1) + max_next_score = batch_next_scores.max(dim=-1, keepdim=True).values + batch_next_scores.masked_fill_(batch_next_scores <= max_next_score - self.beam_threshold, INACTIVE_SCORE) + next_scores.view(self.state.batch_size, self.beam_size, -1) + + # step 2.4: preserving updated lm states + if self.ngram_lm_batch is not None: + last_labels = torch.gather(self.state.batched_hyps.last_label, dim=-1, index=next_indices) + blank_mask = next_labels == self._blank_index + repeating_mask = next_labels == last_labels + preserve_state_mask = repeating_mask | blank_mask | ~self.state.active_mask + + # step 2.4.1: masking blanks and inactive labels to pass to LM, as LM does not support blanks + next_labels_masked = torch.where(blank_mask, 0, next_labels) + + # step 2.4.2: gathering LM states of extended hypotheses + # batch_lm_states: [(BxBeam)] + # batch_lm_states_candidates: [(BxBeam) x V (without blank)] + next_indices_extended = next_indices[:, :, None].expand(self.state.batch_lm_states_candidates.shape) + batch_lm_states_candidates = torch.gather( + self.state.batch_lm_states_candidates, dim=1, index=next_indices_extended + ) + batch_lm_states_prev = torch.gather(self.state.batch_lm_states, dim=1, index=next_indices) + batch_lm_states = torch.gather( + batch_lm_states_candidates, dim=-1, index=next_labels_masked.unsqueeze(-1) + ).squeeze() + + # step 2.4.3: update LM states in State + self.state.batch_lm_states_candidates.copy_(batch_lm_states_candidates) + torch.where(preserve_state_mask, batch_lm_states_prev, batch_lm_states, out=self.state.batch_lm_states) + + # step 2.5: masking inactive hypotheses, updating + recombining batched beam hypoteses + torch.where(self.state.active_mask, next_labels, self.state.NON_EXISTENT_LABEL, out=next_labels) + self.state.batched_hyps.add_results_no_checks_(next_indices, next_labels, next_scores) + self.state.batched_hyps.recombine_hyps_() + + # step 2.6: updating frame idx and active masks + self.state.curr_frame_idx.add_(1) + 
torch.greater_equal(self.state.last_timesteps, self.state.curr_frame_idx, out=self.state.active_mask) + torch.any(self.state.active_mask, out=self.state.active_mask_any) + + def _after_process_batch(self): + """ + Finalizes the decoding process by updating the LM scores with the end-of-sequence (eos) scores. + """ + # step 3: updating LM scores with eos scores + if self.ngram_lm_batch is not None: + eos_score = self.ngram_lm_batch.get_final(self.state.batch_lm_states).view( + self.state.batched_hyps.scores.shape + ) + self.state.batched_hyps.scores += eos_score * self.ngram_lm_alpha + + def __call__( + self, + x: torch.Tensor, + out_len: torch.Tensor, + ) -> BatchedBeamHyps: + if self.cuda_graphs_mode is not None and x.device.type == "cuda": + return self.batched_beam_search_cuda_graphs(decoder_outputs=x, decoder_output_lengths=out_len) + + return self.batched_beam_search_torch(decoder_outputs=x, decoder_output_lengths=out_len) diff --git a/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py b/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py index 61f682e657f5..9467d228e331 100644 --- a/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py @@ -22,6 +22,7 @@ import torch from nemo.collections.asr.parts.k2.classes import GraphIntersectDenseConfig +from nemo.collections.asr.parts.submodules.ctc_batched_beam_decoding import BatchedBeamCTCComputer from nemo.collections.asr.parts.submodules.ngram_lm import DEFAULT_TOKEN_OFFSET from nemo.collections.asr.parts.submodules.wfst_decoder import RivaDecoderConfig, WfstNbestHypothesis from nemo.collections.asr.parts.utils import rnnt_utils @@ -204,7 +205,7 @@ def __call__(self, *args, **kwargs): class BeamCTCInfer(AbstractBeamCTCInfer): - """A greedy CTC decoder. + """A beam CTC decoder. Provides a common abstraction for sample level and batch level greedy decoding. @@ -227,9 +228,9 @@ def __init__( return_best_hypothesis: bool = True, preserve_alignments: bool = False, compute_timestamps: bool = False, - beam_alpha: float = 1.0, + ngram_lm_alpha: float = 0.3, beam_beta: float = 0.0, - kenlm_path: str = None, + ngram_lm_model: str = None, flashlight_cfg: Optional['FlashlightConfig'] = None, pyctcdecode_cfg: Optional['PyCTCDecodeConfig'] = None, ): @@ -260,11 +261,11 @@ def __init__( # Log the beam search algorithm logging.info(f"Beam search algorithm: {search_type}") - self.beam_alpha = beam_alpha + self.ngram_lm_alpha = ngram_lm_alpha self.beam_beta = beam_beta # Default beam search args - self.kenlm_path = kenlm_path + self.ngram_lm_model = ngram_lm_model # PyCTCDecode params if pyctcdecode_cfg is None: @@ -349,9 +350,9 @@ def default_beam_search( if self.default_beam_scorer is None: # Check for filepath - if self.kenlm_path is None or not os.path.exists(self.kenlm_path): + if self.ngram_lm_model is None or not os.path.exists(self.ngram_lm_model): raise FileNotFoundError( - f"KenLM binary file not found at : {self.kenlm_path}. " + f"KenLM binary file not found at : {self.ngram_lm_model}. " f"Please set a valid path in the decoding config." 
) @@ -367,9 +368,9 @@ def default_beam_search( self.default_beam_scorer = BeamSearchDecoderWithLM( vocab=vocab, - lm_path=self.kenlm_path, + lm_path=self.ngram_lm_model, beam_width=self.beam_size, - alpha=self.beam_alpha, + alpha=self.ngram_lm_alpha, beta=self.beam_beta, num_cpus=max(1, os.cpu_count()), input_tensor=False, @@ -451,7 +452,7 @@ def _pyctcdecode_beam_search( if self.pyctcdecode_beam_scorer is None: self.pyctcdecode_beam_scorer = pyctcdecode.build_ctcdecoder( - labels=self.vocab, kenlm_model_path=self.kenlm_path, alpha=self.beam_alpha, beta=self.beam_beta + labels=self.vocab, kenlm_model_path=self.ngram_lm_model, alpha=self.ngram_lm_alpha, beta=self.beam_beta ) # type: pyctcdecode.BeamSearchDecoderCTC x = x.to('cpu').numpy() @@ -533,9 +534,9 @@ def flashlight_beam_search( if self.flashlight_beam_scorer is None: # Check for filepath - if self.kenlm_path is None or not os.path.exists(self.kenlm_path): + if self.ngram_lm_model is None or not os.path.exists(self.ngram_lm_model): raise FileNotFoundError( - f"KenLM binary file not found at : {self.kenlm_path}. " + f"KenLM binary file not found at : {self.ngram_lm_model}. " "Please set a valid path in the decoding config." ) @@ -550,7 +551,7 @@ def flashlight_beam_search( from nemo.collections.asr.modules.flashlight_decoder import FlashLightKenLMBeamSearchDecoder self.flashlight_beam_scorer = FlashLightKenLMBeamSearchDecoder( - lm_path=self.kenlm_path, + lm_path=self.ngram_lm_model, vocabulary=self.vocab, tokenizer=self.tokenizer, lexicon_path=self.flashlight_cfg.lexicon_path, @@ -558,7 +559,7 @@ def flashlight_beam_search( beam_size=self.beam_size, beam_size_token=self.flashlight_cfg.beam_size_token, beam_threshold=self.flashlight_cfg.beam_threshold, - lm_weight=self.beam_alpha, + lm_weight=self.ngram_lm_alpha, word_score=self.beam_beta, unk_weight=self.flashlight_cfg.unk_weight, sil_weight=self.flashlight_cfg.sil_weight, @@ -877,6 +878,108 @@ def _k2_decoding(self, x: torch.Tensor, out_len: torch.Tensor) -> List['WfstNbes return self.k2_decoder.decode(x.to(device=self.device), out_len.to(device=self.device)) +class BeamBatchedCTCInfer(AbstractBeamCTCInfer): + """ + A batched beam CTC decoder. + + Args: + blank_index: int index of the blank token. Can be 0 or len(vocabulary). + beam_size: int size of the beam. + return_best_hypothesis: When set to True (default), returns a single Hypothesis. + When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis. + preserve_alignments: Bool flag which preserves the history of logprobs generated during + decoding (sample / batched). When set to true, the Hypothesis will contain + the non-null value for `logprobs` in it. Here, `logprobs` is a torch.Tensors. + compute_timestamps: A bool flag, which determines whether to compute the character/subword, or + word based timestamp mapping the output log-probabilities to discrite intervals of timestamps. + The timestamps will be available in the returned Hypothesis.timestep as a dictionary. + ngram_lm_alpha: float, the language model weight. + beam_beta: float, the word insertion weight. + beam_threshold: float, the beam pruning threshold. + ngram_lm_model: str, the path to the ngram model. + allow_cuda_graphs: bool, whether to allow cuda graphs for the beam search algorithm. 
+ """ + + def __init__( + self, + blank_index: int, + beam_size: int, + return_best_hypothesis: bool = True, + preserve_alignments: bool = False, + compute_timestamps: bool = False, + ngram_lm_alpha: float = 1.0, + beam_beta: float = 0.0, + beam_threshold: float = 20.0, + ngram_lm_model: str = None, + allow_cuda_graphs: bool = True, + ): + super().__init__(blank_id=blank_index, beam_size=beam_size) + + self.return_best_hypothesis = return_best_hypothesis + self.preserve_alignments = preserve_alignments + self.compute_timestamps = compute_timestamps + self.allow_cuda_graphs = allow_cuda_graphs + + if self.compute_timestamps: + raise ValueError("`Compute timestamps` is not supported for batched beam search.") + if self.preserve_alignments: + raise ValueError("`Preserve alignments` is not supported for batched beam search.") + + self.ngram_lm_alpha = ngram_lm_alpha + self.beam_beta = beam_beta + self.beam_threshold = beam_threshold + + # Default beam search args + self.ngram_lm_model = ngram_lm_model + + self.search_algorithm = BatchedBeamCTCComputer( + blank_index=blank_index, + beam_size=beam_size, + return_best_hypothesis=return_best_hypothesis, + preserve_alignments=preserve_alignments, + compute_timestamps=compute_timestamps, + ngram_lm_alpha=ngram_lm_alpha, + beam_beta=beam_beta, + beam_threshold=beam_threshold, + ngram_lm_model=ngram_lm_model, + allow_cuda_graphs=allow_cuda_graphs, + ) + + @typecheck() + def forward( + self, + decoder_output: torch.Tensor, + decoder_lengths: torch.Tensor, + ) -> Tuple[List[Union[rnnt_utils.Hypothesis, rnnt_utils.NBestHypotheses]]]: + """Returns a list of hypotheses given an input batch of the encoder hidden embedding. + Output token is generated auto-repressively. + + Args: + decoder_output: A tensor of size (batch, timesteps, features). + decoder_lengths: list of int representing the length of each sequence + output sequence. + + Returns: + packed list containing batch number of sentences (Hypotheses). + """ + with torch.no_grad(), torch.inference_mode(): + if decoder_output.ndim != 3: + raise ValueError( + f"`decoder_output` must be a tensor of shape [B, T, V] (log probs, float). 
" + f"Provided shape = {decoder_output.shape}" + ) + + batched_beam_hyps = self.search_algorithm(decoder_output, decoder_lengths) + + batch_size = decoder_lengths.shape[0] + if self.return_best_hypothesis: + hyps = batched_beam_hyps.to_hyps_list(score_norm=False)[:batch_size] + else: + hyps = batched_beam_hyps.to_nbest_hyps_list(score_norm=False)[:batch_size] + + return (hyps,) + + @dataclass class PyCTCDecodeConfig: # These arguments cannot be imported from pyctcdecode (optional dependency) @@ -906,10 +1009,14 @@ class BeamCTCInferConfig: preserve_alignments: bool = False compute_timestamps: bool = False return_best_hypothesis: bool = True + allow_cuda_graphs: bool = True - beam_alpha: float = 1.0 - beam_beta: float = 0.0 - kenlm_path: Optional[str] = None + beam_alpha: Optional[float] = None # Deprecated + beam_beta: float = 1.0 + beam_threshold: float = 20.0 + kenlm_path: Optional[str] = None # Deprecated, default should be None + ngram_lm_alpha: Optional[float] = 1.0 + ngram_lm_model: Optional[str] = None flashlight_cfg: Optional[FlashlightConfig] = field(default_factory=lambda: FlashlightConfig()) pyctcdecode_cfg: Optional[PyCTCDecodeConfig] = field(default_factory=lambda: PyCTCDecodeConfig()) diff --git a/nemo/collections/asr/parts/submodules/ctc_decoding.py b/nemo/collections/asr/parts/submodules/ctc_decoding.py index b7b824a1dc90..2e08205f3bb3 100644 --- a/nemo/collections/asr/parts/submodules/ctc_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_decoding.py @@ -178,15 +178,15 @@ class AbstractCTCDecoding(ConfidenceMixin): optional bool, whether to return just the best hypothesis or all of the hypotheses after beam search has concluded. This flag is set by default. - beam_alpha: + ngram_lm_alpha: float, the strength of the Language model on the final score of a token. - final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length. + final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length. beam_beta: float, the strength of the sequence length penalty on the final score of a token. - final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length. + final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length. - kenlm_path: + ngram_lm_model: str, path to a KenLM ARPA or .binary file (depending on the strategy chosen). If the path is invalid (file is not found at path), will raise a deferred error at the moment of calculation of beam search, so that users may update / change the decoding strategy @@ -226,7 +226,7 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[ self.segment_seperators = self.cfg.get('segment_seperators', ['.', '?', '!']) self.segment_gap_threshold = self.cfg.get('segment_gap_threshold', None) - possible_strategies = ['greedy', 'greedy_batch', 'beam', 'pyctcdecode', 'flashlight', 'wfst'] + possible_strategies = ['greedy', 'greedy_batch', 'beam', 'pyctcdecode', 'flashlight', 'wfst', 'beam_batch'] if self.cfg.strategy not in possible_strategies: raise ValueError(f"Decoding strategy must be one of {possible_strategies}. 
Given {self.cfg.strategy}") @@ -267,6 +267,20 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[ if self.compute_timestamps is not None: self.compute_timestamps |= self.preserve_frame_confidence + if self.cfg.strategy in ['flashlight', 'wfst', 'beam_batch', 'pyctcdecode', 'beam']: + if self.cfg.beam.beam_alpha is not None: + logging.warning( + "`beam_alpha` is deprecated and will be removed in a future release. " + "Please use `ngram_lm_alpha` instead." + ) + self.cfg.beam.ngram_lm_alpha = self.cfg.beam.beam_alpha + if self.cfg.beam.kenlm_path is not None: + logging.warning( + "`kenlm_path` is deprecated and will be removed in a future release. " + "Please use `ngram_lm_model` instead." + ) + self.cfg.beam.ngram_lm_model = self.cfg.beam.kenlm_path + if self.cfg.strategy == 'greedy': self.decoding = ctc_greedy_decoding.GreedyCTCInfer( blank_id=self.blank_id, @@ -294,9 +308,9 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[ return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True), preserve_alignments=self.preserve_alignments, compute_timestamps=self.compute_timestamps, - beam_alpha=self.cfg.beam.get('beam_alpha', 1.0), + ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 1.0), beam_beta=self.cfg.beam.get('beam_beta', 0.0), - kenlm_path=self.cfg.beam.get('kenlm_path', None), + ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), ) self.decoding.override_fold_consecutive_value = False @@ -310,9 +324,9 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[ return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True), preserve_alignments=self.preserve_alignments, compute_timestamps=self.compute_timestamps, - beam_alpha=self.cfg.beam.get('beam_alpha', 1.0), + ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 1.0), beam_beta=self.cfg.beam.get('beam_beta', 0.0), - kenlm_path=self.cfg.beam.get('kenlm_path', None), + ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), pyctcdecode_cfg=self.cfg.beam.get('pyctcdecode_cfg', None), ) @@ -327,9 +341,9 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[ return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True), preserve_alignments=self.preserve_alignments, compute_timestamps=self.compute_timestamps, - beam_alpha=self.cfg.beam.get('beam_alpha', 1.0), + ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 1.0), beam_beta=self.cfg.beam.get('beam_beta', 0.0), - kenlm_path=self.cfg.beam.get('kenlm_path', None), + ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), flashlight_cfg=self.cfg.beam.get('flashlight_cfg', None), ) @@ -357,6 +371,22 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[ self.decoding.override_fold_consecutive_value = False + elif self.cfg.strategy == 'beam_batch': + self.decoding = ctc_beam_decoding.BeamBatchedCTCInfer( + blank_index=blank_id, + beam_size=self.cfg.beam.get('beam_size', 1), + return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True), + preserve_alignments=self.preserve_alignments, + compute_timestamps=self.compute_timestamps, + ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 1.0), + beam_beta=self.cfg.beam.get('beam_beta', 0.0), + beam_threshold=self.cfg.beam.get('beam_threshold', 20.0), + ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), + allow_cuda_graphs=self.cfg.beam.get('allow_cuda_graphs', True), + ) + + self.decoding.override_fold_consecutive_value = False + else: raise 
ValueError( f"Incorrect decoding strategy supplied. Must be one of {possible_strategies}\n" @@ -1051,15 +1081,15 @@ class CTCDecoding(AbstractCTCDecoding): optional bool, whether to return just the best hypothesis or all of the hypotheses after beam search has concluded. This flag is set by default. - beam_alpha: + ngram_lm_alpha: float, the strength of the Language model on the final score of a token. - final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length. + final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length. beam_beta: float, the strength of the sequence length penalty on the final score of a token. - final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length. + final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length. - kenlm_path: + ngram_lm_model: str, path to a KenLM ARPA or .binary file (depending on the strategy chosen). If the path is invalid (file is not found at path), will raise a deferred error at the moment of calculation of beam search, so that users may update / change the decoding strategy @@ -1340,15 +1370,15 @@ class CTCBPEDecoding(AbstractCTCDecoding): optional bool, whether to return just the best hypothesis or all of the hypotheses after beam search has concluded. This flag is set by default. - beam_alpha: + ngram_lm_alpha: float, the strength of the Language model on the final score of a token. - final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length. + final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length. beam_beta: float, the strength of the sequence length penalty on the final score of a token. - final_score = acoustic_score + beam_alpha * lm_score + beam_beta * seq_length. + final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length. - kenlm_path: + ngram_lm_model: str, path to a KenLM ARPA or .binary file (depending on the strategy chosen). 
If the path is invalid (file is not found at path), will raise a deferred error at the moment of calculation of beam search, so that users may update / change the decoding strategy diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index 11556b58d7aa..1f6cb6015327 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -40,7 +40,7 @@ from nemo.collections.asr.parts.submodules.rnnt_maes_batched_computer import ModifiedAESBatchedRNNTComputer from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin -from nemo.collections.asr.parts.utils.rnnt_batched_beam_utils import BlankLMScoreMode, PruningMode +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import BlankLMScoreMode, PruningMode from nemo.collections.asr.parts.utils.rnnt_utils import ( HATJointOutput, Hypothesis, diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 0c6c43060eb4..6034f4833743 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -25,7 +25,7 @@ from nemo.collections.asr.parts.submodules import rnnt_beam_decoding, rnnt_greedy_decoding, tdt_beam_decoding from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig, ConfidenceMixin -from nemo.collections.asr.parts.utils.rnnt_batched_beam_utils import BlankLMScoreMode, PruningMode +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import BlankLMScoreMode, PruningMode from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec diff --git a/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py index c1827805a4bf..e74e60aecffc 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py @@ -18,7 +18,7 @@ from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin -from nemo.collections.asr.parts.utils.rnnt_batched_beam_utils import ( +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import ( INACTIVE_SCORE, NON_EXISTENT_LABEL_VALUE, BatchedBeamHyps, diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index 12d817c6c597..b00fad5a9e09 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -21,7 +21,7 @@ from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin -from nemo.collections.asr.parts.utils.rnnt_batched_beam_utils import ( +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import ( INACTIVE_SCORE, NON_EXISTENT_LABEL_VALUE, BatchedBeamHyps, @@ -508,7 +508,7 
@@ def modified_alsd_torch( batched_hyps.add_results_no_checks_(hyps_indices, next_labels, next_hyps_prob) # step 4: recombine hypotheses: sum probabilities of identical hypotheses. - batched_hyps.recombine_hyps() + batched_hyps.recombine_hyps_() # step 5: update decoder state + decoder output (+ lm state/scores) # step 5.1: mask invalid value labels with blank to avoid errors (refer to step 2.2) @@ -1008,7 +1008,7 @@ def _loop_body(self): ) # step 4: recombine hypotheses: sum probabilities of identical hypotheses. - self.state.batched_hyps.recombine_hyps() + self.state.batched_hyps.recombine_hyps_() def _loop_update_decoder(self): """ diff --git a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py index 61160018e138..f98ef888e7fa 100644 --- a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py @@ -37,7 +37,7 @@ from nemo.collections.asr.parts.submodules.rnnt_beam_decoding import pack_hypotheses from nemo.collections.asr.parts.submodules.tdt_malsd_batched_computer import ModifiedALSDBatchedTDTComputer from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin -from nemo.collections.asr.parts.utils.rnnt_batched_beam_utils import BlankLMScoreMode, PruningMode +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import BlankLMScoreMode, PruningMode from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses, is_prefix from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType diff --git a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py index 3587165803b4..5bf43dfc938d 100644 --- a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py @@ -21,7 +21,7 @@ from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin -from nemo.collections.asr.parts.utils.rnnt_batched_beam_utils import ( +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import ( INACTIVE_SCORE, NON_EXISTENT_LABEL_VALUE, BatchedBeamHyps, @@ -193,7 +193,7 @@ def __init__( init_length=max_time * (max_symbols + 1) if max_symbols is not None else max_time, device=device, float_dtype=float_dtype, - is_tdt=True, + model_type='tdt', ) def need_reinit(self, encoder_output_projected: torch.Tensor) -> bool: @@ -390,7 +390,7 @@ def modified_alsd_torch( init_length=max_time * (self.max_symbols + 1) if self.max_symbols is not None else max_time, device=device, float_dtype=float_dtype, - is_tdt=True, + model_type='tdt', ) last_labels_wb = torch.full( @@ -543,7 +543,7 @@ def modified_alsd_torch( batched_hyps.add_results_no_checks_(hyps_indices, next_labels, next_hyps_prob, next_label_durations) # step 4: recombine hypotheses: sum probabilities of identical hypotheses. - batched_hyps.recombine_hyps() + batched_hyps.recombine_hyps_() # step 5: update decoder state + decoder output (+ lm state/scores) # step 5.1: mask invalid value labels with blank to avoid errors (refer to step 2.2) @@ -1130,7 +1130,7 @@ def _loop_body(self): ) # step 4: recombine hypotheses: sum probabilities of identical hypotheses. 
- self.state.batched_hyps.recombine_hyps() + self.state.batched_hyps.recombine_hyps_() def _loop_update_decoder(self): """ diff --git a/nemo/collections/asr/parts/utils/rnnt_batched_beam_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py similarity index 76% rename from nemo/collections/asr/parts/utils/rnnt_batched_beam_utils.py rename to nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py index 627ae979a70f..e0af3b7623d7 100644 --- a/nemo/collections/asr/parts/utils/rnnt_batched_beam_utils.py +++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py @@ -67,6 +67,14 @@ class PruningMode(PrettyStrEnum): LATE = "late" +class ASRModelTypeEnum(PrettyStrEnum): + """Specifies model type.""" + + RNNT = "rnnt" + TDT = "tdt" + CTC = "ctc" + + class BatchedBeamHyps: """Class to store batch of beam hypotheses (labels, time_indices, scores) for efficient batched beam decoding""" @@ -79,7 +87,7 @@ def __init__( device: torch.device = None, float_dtype: torch.dtype = None, store_prefix_hashes: Optional[bool] = False, - is_tdt: Optional[bool] = False, + model_type: Optional[ASRModelTypeEnum | str] = ASRModelTypeEnum.RNNT, ): """ Initializes the batched beam hypotheses utility for Transducer decoding (RNN-T and TDT models). @@ -91,7 +99,7 @@ def __init__( device (torch.device): The device on which tensors will be allocated. Defaults to None. float_dtype (torch.dtype): The floating-point data type. Defaults to None. store_prefix_hashes (bool, optional): Whether to store prefix hashes for hypotheses. Defaults to False. - is_tdt: (bool, optional): Whether will be used for TDT models. Defaults to false. + model_type: (str or ModelTypeEnum, optional): Model type, either 'rnnt', 'tdt' or 'ctc'. Defaults to 'rnnt'. """ if beam_size <= 0: @@ -105,7 +113,7 @@ def __init__( self.INACTIVE_SCORE_TENSOR = torch.tensor(INACTIVE_SCORE, device=device, dtype=float_dtype) self.ZERO_TENSOR = torch.tensor(0, device=device, dtype=torch.long) - self.is_tdt = is_tdt + self.model_type = ASRModelTypeEnum(model_type) self.store_prefix_hashes = store_prefix_hashes self._max_length = init_length self.beam_size = beam_size @@ -114,7 +122,7 @@ def __init__( self.batch_indices = torch.arange(self.batch_size, device=device) self.beam_indices = torch.arange(self.beam_size, device=device) - # non-blank and full lengths + # Non-blank (non-blank and non-repeating for CTC) and full lengths self.current_lengths_nb = torch.zeros([batch_size, self.beam_size], device=device, dtype=torch.long) self.current_lengths_wb = torch.zeros([batch_size, self.beam_size], device=device, dtype=torch.long) @@ -131,9 +139,6 @@ def __init__( device=device, dtype=torch.long, ) # links to prefices - self.timestamps = torch.zeros( - (batch_size, self.beam_size, self._max_length), device=device, dtype=torch.long - ) # timestamps # Initializing beam scores: Initially, only a single hypothesis is active within the beam. 
self.scores = torch.full( @@ -144,8 +149,6 @@ def __init__( self.last_label = torch.full( (batch_size, self.beam_size), fill_value=NON_EXISTENT_LABEL_VALUE, device=device, dtype=torch.long ) - self.next_timestamp = torch.zeros((batch_size, self.beam_size), device=device, dtype=torch.long) - self.last_timestamp_lasts = torch.zeros((batch_size, self.beam_size), device=device, dtype=torch.long) self.transcript_hash = torch.full( [batch_size, self.beam_size], device=device, dtype=torch.long, fill_value=INIT_HASH_VALUE @@ -155,6 +158,19 @@ def __init__( [batch_size, self.beam_size], device=device, dtype=torch.long, fill_value=INIT_PREFIX_HASH_VALUE ) + if self.model_type == ASRModelTypeEnum.CTC: + # CTC frames and tokens are aligned, so we can precompute timestamps + self.timestamps = self._create_timestamps_tensor(self._max_length) # timestamps + else: + # timestamps for transducer models + self.timestamps = torch.zeros( + (batch_size, self.beam_size, self._max_length), device=device, dtype=torch.long + ) # timestamps + + # tracking last frame index and number of labels for the last frama + self.next_timestamp = torch.zeros((batch_size, self.beam_size), device=device, dtype=torch.long) + self.last_timestamp_lasts = torch.zeros((batch_size, self.beam_size), device=device, dtype=torch.long) + def clear_(self): """ Clears and resets the internal state of the object. @@ -163,21 +179,26 @@ def clear_(self): self.current_lengths_nb.fill_(0) self.current_lengths_wb.fill_(0) - self.transcript_wb.fill_(0) + self.transcript_wb.fill_(NON_EXISTENT_LABEL_VALUE) self.transcript_wb_prev_ptr.fill_(INIT_POINTER_VALUE) - self.timestamps.fill_(0) self.scores.fill_(INACTIVE_SCORE) self.scores[:, 0].fill_(0.0) self.last_label.fill_(NON_EXISTENT_LABEL_VALUE) - self.next_timestamp.fill_(0) - self.last_timestamp_lasts.fill_(0) self.transcript_hash.fill_(INIT_HASH_VALUE) if self.store_prefix_hashes: self.transcript_prefix_hash.fill_(INIT_PREFIX_HASH_VALUE) + # model specific parameters + if self.model_type == ASRModelTypeEnum.CTC: + self.timestamps.copy_(self._create_timestamps_tensor(self._max_length)) + else: + self.timestamps.fill_(0) + self.next_timestamp.fill_(0) + self.last_timestamp_lasts.fill_(0) + def _allocate_more(self): """ Dynamically allocates more memory for the internal buffers. @@ -190,7 +211,10 @@ def _allocate_more(self): (self.transcript_wb_prev_ptr, torch.full_like(self.transcript_wb_prev_ptr, fill_value=INIT_POINTER_VALUE)), dim=-1, ) - self.timestamps = torch.cat((self.timestamps, torch.zeros_like(self.timestamps)), dim=-1) + if self.model_type == ASRModelTypeEnum.CTC: + self.timestamps = self._create_timestamps_tensor(2 * self._max_length) + else: + self.timestamps = torch.cat((self.timestamps, torch.zeros_like(self.timestamps)), dim=-1) self._max_length *= 2 @@ -205,13 +229,14 @@ def add_results_( Updates batch of beam hypotheses with labels. If the maximum allowed length is exceeded, underlying memory is doubled. Args: - hyps_indices (torch.Tensor): Indices of the hypotheses to be updated. + next_indices (torch.Tensor): Indices of the hypotheses to be updated. next_labels (torch.Tensor): Labels corresponding to the next step in the beam search. next_hyps_prob (torch.Tensor): Probabilities of the next hypotheses. + next_label_durations (torch.Tensor, optional): Durations associated with the next labels. Required when `model_type='tdt'`. 
""" - if self.is_tdt and next_label_durations is None: - raise ValueError("`next_label_durations` is required when `self.is_tdt` is True.") + if self.model_type == ASRModelTypeEnum.TDT and next_label_durations is None: + raise ValueError("`next_label_durations` is required when model type is TDT.") if (self.current_lengths_wb + 1).max() >= self._max_length: self._allocate_more() @@ -233,15 +258,15 @@ def add_results_no_checks_( """ Updates batch of beam hypotheses with labels. Args: - hyps_indices (torch.Tensor): Indices of the hypotheses to be updated. + next_indices (torch.Tensor): Indices of the hypotheses to be updated. next_labels (torch.Tensor): Labels corresponding to the next step in the beam search. next_hyps_prob (torch.Tensor): Probabilities of the next hypotheses. - next_label_durations (torch.Tensor, optional): Durations associated with the next labels. Required when `is_tdt=True`. + next_label_durations (torch.Tensor, optional): Durations associated with the next labels. Required when `model_type='tdt'`. """ - if self.is_tdt and next_label_durations is None: - raise ValueError("`next_label_durations` is required when `self.is_tdt` is True.") + if self.model_type == ASRModelTypeEnum.TDT and next_label_durations is None: + raise ValueError("`next_label_durations` is required when model type is TDT.") - timesteps = torch.gather(self.next_timestamp, dim=-1, index=next_indices) + last_labels = torch.gather(self.last_label, dim=-1, index=next_indices) self.transcript_wb.scatter_(dim=-1, index=self.current_lengths_wb.unsqueeze(-1), src=next_labels.unsqueeze(-1)) self.transcript_wb_prev_ptr.scatter_( dim=-1, index=self.current_lengths_wb.unsqueeze(-1), src=next_indices.unsqueeze(-1) @@ -249,9 +274,13 @@ def add_results_no_checks_( is_extended = next_labels >= 0 extended_with_blank = next_labels == self.blank_index - extended_with_label = (~extended_with_blank) & (is_extended) + extended_with_label = (is_extended) & (~extended_with_blank) + if self.model_type == ASRModelTypeEnum.CTC: + # for CTC last non-blank and non-repeated label + extended_with_label = (extended_with_label) & (next_labels != last_labels) # non-repeated non-blank label - if not self.is_tdt: + if self.model_type == ASRModelTypeEnum.RNNT: + timesteps = torch.gather(self.next_timestamp, dim=-1, index=next_indices) self.timestamps.scatter_( dim=-1, index=self.current_lengths_wb.unsqueeze(-1), @@ -264,7 +293,8 @@ def add_results_no_checks_( torch.gather(self.last_timestamp_lasts, dim=-1, index=next_indices) + extended_with_label, out=self.last_timestamp_lasts, ) - else: + elif self.model_type == ASRModelTypeEnum.TDT: + timesteps = torch.gather(self.next_timestamp, dim=-1, index=next_indices) next_label_durations = torch.where(is_extended, next_label_durations, 0) self.timestamps.scatter_( dim=-1, @@ -286,14 +316,6 @@ def add_results_no_checks_( self.scores.copy_(next_hyps_prob) prev_transcript_hash = torch.gather(self.transcript_hash, dim=-1, index=next_indices) - # track last label - torch.where( - extended_with_label, - next_labels, - torch.gather(self.last_label, dim=-1, index=next_indices), - out=self.last_label, - ) - # update hashes and prefix hashes torch.where( extended_with_label, @@ -301,13 +323,22 @@ def add_results_no_checks_( prev_transcript_hash, out=self.transcript_hash, ) + + if self.model_type == ASRModelTypeEnum.CTC: + # track last label + torch.where(is_extended, next_labels, last_labels, out=self.last_label) + else: + # track last non-blank label + torch.where(extended_with_label, next_labels, 
last_labels, out=self.last_label) + + # store prefix hashes for batched maes if self.store_prefix_hashes: prev_transcript_prefix_hash = torch.gather(self.transcript_prefix_hash, dim=-1, index=next_indices) torch.where( extended_with_label, prev_transcript_hash, prev_transcript_prefix_hash, out=self.transcript_prefix_hash ) - def recombine_hyps(self): + def recombine_hyps_(self): """ Recombines hypotheses in the beam search by merging equivalent hypotheses and updating their scores. This method identifies hypotheses that are equivalent based on their transcript hash, last label, @@ -326,7 +357,7 @@ def recombine_hyps(self): & (self.current_lengths_nb[:, :, None] == self.current_lengths_nb[:, None, :]) ) - if self.is_tdt: + if self.model_type == ASRModelTypeEnum.TDT: hyps_equal &= self.next_timestamp[:, :, None] == self.next_timestamp[:, None, :] scores_matrix = torch.where( @@ -338,7 +369,10 @@ def recombine_hyps(self): scores_to_keep = ( torch.arange(self.beam_size, device=scores_argmax.device, dtype=torch.long)[None, :] == scores_argmax ) - new_scores = torch.logsumexp(scores_matrix, dim=-1, keepdim=False) + if self.model_type == ASRModelTypeEnum.CTC: + new_scores = torch.max(scores_matrix, dim=-1, keepdim=False).values + else: + new_scores = torch.logsumexp(scores_matrix, dim=-1, keepdim=False) torch.where(scores_to_keep, new_scores.to(self.scores.dtype), self.INACTIVE_SCORE_TENSOR, out=self.scores) def remove_duplicates(self, labels: torch.Tensor, total_logps: torch.Tensor): @@ -433,21 +467,21 @@ def to_hyps_list(self, score_norm: bool = True) -> list[Hypothesis]: list[Hypothesis]: A list where each element corresponds to a batch and contains best hypothesis. """ - self.flatten_sort_(score_norm) scores = self.scores[self.batch_indices, 0].tolist() max_idx = self.current_lengths_wb.max() - 1 - timestamps = self.timestamps[..., 0, : max_idx + 1].cpu().detach().numpy() - transcripts = self.transcript_wb[..., 0, : max_idx + 1].cpu().detach().numpy() + timestamps = self.timestamps[..., 0, : max_idx + 1] + transcripts = self.transcript_wb[..., 0, : max_idx + 1] hypotheses = [ Hypothesis( score=scores[batch_idx], - y_sequence=transcripts[batch_idx][ - mask := (transcripts[batch_idx] >= 0) & (transcripts[batch_idx] != self.blank_index) - ], - timestamp=timestamps[batch_idx][mask], + y_sequence=transcripts[batch_idx][mask := self._create_transcripts_mask(transcripts[batch_idx])] + .cpu() + .detach() + .numpy(), + timestamp=timestamps[batch_idx][mask].cpu().detach().numpy(), alignments=None, dec_state=None, ) @@ -470,18 +504,20 @@ def to_nbest_hyps_list(self, score_norm: bool = True) -> list[NBestHypotheses]: scores = self.scores.tolist() max_idx = self.current_lengths_wb.max() - 1 - transcripts = self.transcript_wb[..., : max_idx + 1].cpu().detach().numpy() - timestamps = self.timestamps[..., : max_idx + 1].cpu().detach().numpy() + transcripts = self.transcript_wb[..., : max_idx + 1] + timestamps = self.timestamps[..., : max_idx + 1] hypotheses = [ NBestHypotheses( [ Hypothesis( score=scores[batch_idx][beam_idx], y_sequence=transcripts[batch_idx][beam_idx][ - mask := (transcripts[batch_idx][beam_idx] >= 0) - & (transcripts[batch_idx][beam_idx] != self.blank_index) - ], - timestamp=timestamps[batch_idx][beam_idx][mask], + mask := self._create_transcripts_mask(transcripts[batch_idx][beam_idx]) + ] + .cpu() + .detach() + .numpy(), + timestamp=timestamps[batch_idx][beam_idx][mask].cpu().detach().numpy(), alignments=None, dec_state=None, ) @@ -499,8 +535,6 @@ def flatten_sort_(self, score_norm: bool 
= True):
         Args:
             score_norm (bool, optional): If True, normalizes the scores by dividing them by the current lengths of the hypotheses plus one. Defaults to True.
-        Returns:
-            list[Hypothesis]: A list of sorted and flattened hypotheses.
         This method performs the following steps:
         1. Normalizes the scores if `score_norm` is True.
         2. Sorts the normalized scores in descending order and retrieves the corresponding indices.
@@ -520,7 +554,8 @@ def flatten_sort_(self, score_norm: bool = True):
         for idx in range(max_idx, -1, -1):
             self.transcript_wb[..., idx].copy_(self.transcript_wb[self.batch_indices.unsqueeze(-1), ptrs, idx])
-            self.timestamps[..., idx].copy_(self.timestamps[self.batch_indices.unsqueeze(-1), ptrs, idx])
+            if self.model_type == ASRModelTypeEnum.TDT or self.model_type == ASRModelTypeEnum.RNNT:
+                self.timestamps[..., idx].copy_(self.timestamps[self.batch_indices.unsqueeze(-1), ptrs, idx])
             ptrs = self.transcript_wb_prev_ptr[self.batch_indices.unsqueeze(-1), ptrs, idx]
 
         self.transcript_wb_prev_ptr[..., : max_idx + 1].copy_(self.beam_indices.unsqueeze(0).unsqueeze(-1))
@@ -529,9 +564,60 @@ def flatten_sort_(self, score_norm: bool = True):
         self.current_lengths_wb.copy_(torch.gather(self.current_lengths_wb, dim=-1, index=indices))
 
         self.last_label.copy_(torch.gather(self.last_label, dim=-1, index=indices))
-        self.next_timestamp.copy_(torch.gather(self.next_timestamp, dim=-1, index=indices))
-        self.last_timestamp_lasts.copy_(torch.gather(self.last_timestamp_lasts, dim=-1, index=indices))
+
+        if self.model_type == ASRModelTypeEnum.TDT or self.model_type == ASRModelTypeEnum.RNNT:
+            self.next_timestamp.copy_(torch.gather(self.next_timestamp, dim=-1, index=indices))
+            self.last_timestamp_lasts.copy_(torch.gather(self.last_timestamp_lasts, dim=-1, index=indices))
 
         self.transcript_hash.copy_(torch.gather(self.transcript_hash, dim=-1, index=indices))
         if self.store_prefix_hashes:
             self.transcript_prefix_hash.copy_(torch.gather(self.transcript_prefix_hash, dim=-1, index=indices))
+
+    def _create_fold_consecutive_mask(self, transcript):
+        """
+        Creates a mask to filter consecutive duplicates, blanks, and invalid tokens in a transcript.
+        Args:
+            transcript (torch.Tensor): 1D tensor of token sequence.
+        Returns:
+            torch.Tensor: Boolean mask indicating valid tokens.
+        """
+        device = transcript.device
+        mask = (
+            (transcript >= 0)
+            & torch.cat([torch.tensor([True], device=device), transcript[1:] != transcript[:-1]])
+            & (transcript != self.blank_index)
+        )
+
+        return mask
+
+    def _create_timestamps_tensor(self, max_time):
+        """
+        Generates a tensor of timestamps.
+
+        In CTC, labels align with input frames, allowing timestamps to be precomputed.
+
+        Args:
+            max_time (int): The maximum number of time steps (frames) to include in the tensor.
+
+        Returns:
+            torch.Tensor: A tensor of shape (batch_size, beam_size, max_time) containing
+                sequential timestamps for each batch and beam.
+        """
+        return torch.arange(max_time, device=self.device, dtype=torch.long)[None, None, :].repeat(
+            self.batch_size, self.beam_size, 1
+        )
+
+    def _create_transcripts_mask(self, transcripts: torch.Tensor):
+        """
+        Processes the transcripts.
+        For RNN-T and TDT removes blanks.
+        For CTC removes consecutive duplicates and blanks.
+        Args:
+            transcripts (torch.Tensor): 1D tensor of token sequence.
+        Returns:
+            torch.Tensor: Boolean mask indicating valid tokens.
+ """ + if self.model_type == ASRModelTypeEnum.CTC: + return self._create_fold_consecutive_mask(transcripts) + else: + return (transcripts >= 0) & (transcripts != self.blank_index) diff --git a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py index 9735180d2659..72e5bb70f961 100644 --- a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py +++ b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py @@ -147,10 +147,10 @@ def beam_search_eval( # Override the beam search config with current search candidate configuration cfg.decoding.beam_size = beam_width - cfg.decoding.beam_alpha = beam_alpha + cfg.decoding.ngram_lm_alpha = beam_alpha cfg.decoding.beam_beta = beam_beta cfg.decoding.return_best_hypothesis = False - cfg.decoding.kenlm_path = cfg.kenlm_model_file + cfg.decoding.ngram_lm_model = cfg.kenlm_model_file # Update model's decoding strategy config model.cfg.decoding.strategy = cfg.decoding_strategy diff --git a/tests/collections/asr/decoding/test_batched_rnnt_decoding.py b/tests/collections/asr/decoding/test_batched_beam_decoding.py similarity index 76% rename from tests/collections/asr/decoding/test_batched_rnnt_decoding.py rename to tests/collections/asr/decoding/test_batched_beam_decoding.py index 2c9a0a3522fe..53aecad794f6 100644 --- a/tests/collections/asr/decoding/test_batched_rnnt_decoding.py +++ b/tests/collections/asr/decoding/test_batched_beam_decoding.py @@ -24,6 +24,8 @@ from tqdm import tqdm from nemo.collections.asr.models import ASRModel +from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.parts.submodules.ctc_beam_decoding import BeamBatchedCTCInfer from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel from nemo.collections.asr.parts.submodules.rnnt_beam_decoding import BeamBatchedRNNTInfer from nemo.collections.asr.parts.submodules.tdt_beam_decoding import BeamBatchedTDTInfer @@ -33,6 +35,7 @@ from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ RNNT_MODEL = "stt_en_conformer_transducer_small" +CTC_MODEL = "nvidia/stt_en_conformer_ctc_small" TDT_MODEL = "nvidia/stt_en_fastconformer_tdt_large" MAX_SAMPLES = 10 @@ -74,16 +77,31 @@ def tdt_model(): return model +@pytest.fixture(scope="module") +def ctc_model(): + model = ASRModel.from_pretrained(model_name=CTC_MODEL, map_location="cpu") + model.eval() + return model + + # encoder output fixtures @pytest.fixture(scope="module") def get_rnnt_encoder_output(rnnt_model, test_audio_filenames): - encoder_output, encoded_lengths = get_model_encoder_output(test_audio_filenames, MAX_SAMPLES, rnnt_model) + encoder_output, encoded_lengths = get_transducer_model_encoder_output( + test_audio_filenames, MAX_SAMPLES, rnnt_model + ) return encoder_output, encoded_lengths @pytest.fixture(scope="module") def get_tdt_encoder_output(tdt_model, test_audio_filenames): - encoder_output, encoded_lengths = get_model_encoder_output(test_audio_filenames, MAX_SAMPLES, tdt_model) + encoder_output, encoded_lengths = get_transducer_model_encoder_output(test_audio_filenames, MAX_SAMPLES, tdt_model) + return encoder_output, encoded_lengths + + +@pytest.fixture(scope="module") +def get_ctc_output(ctc_model, test_audio_filenames): + encoder_output, encoded_lengths = get_ctc_model_output(test_audio_filenames, MAX_SAMPLES, ctc_model) return encoder_output, encoded_lengths @@ -96,7 +114,7 @@ def kenlm_model_path(tmp_path_factory, test_data_dir): return 
f"{lm_nemo_path}" -def get_model_encoder_output( +def get_transducer_model_encoder_output( test_audio_filenames, num_samples: int, model: ASRModel, @@ -124,6 +142,34 @@ def get_model_encoder_output( return encoded_outputs, encoded_length +def get_ctc_model_output( + test_audio_filenames, + num_samples: int, + model: ASRModel, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, +): + audio_filepaths = test_audio_filenames[:num_samples] + + with torch.no_grad(): + model.preprocessor.featurizer.dither = 0.0 + model.preprocessor.featurizer.pad_to = 0 + model.eval() + + all_inputs, all_lengths = [], [] + for audio_file in tqdm(audio_filepaths, desc="Loading audio files"): + audio_tensor, _ = load_audio(audio_file) + all_inputs.append(audio_tensor) + all_lengths.append(torch.tensor(audio_tensor.shape[0], dtype=torch.int64)) + + input_batch = torch.nn.utils.rnn.pad_sequence(all_inputs, batch_first=True).to(device=device, dtype=dtype) + length_batch = torch.tensor(all_lengths, dtype=torch.int64).to(device) + + log_probs, encoded_length, _ = model(input_signal=input_batch, input_signal_length=length_batch) + + return log_probs, encoded_length + + def print_unit_test_info(strategy, batch_size, beam_size, allow_cuda_graphs, device): print( f"""Beam search algorithm: {strategy}, @@ -198,12 +244,20 @@ def print_res_nbest_hyps(batch_nbest_hyps): print() -def decode_text_from_hypotheses(hyps, decoding): - return decoding.decode_hypothesis(hyps) +def decode_text_from_hypotheses(hyps, model): + if isinstance(model, EncDecCTCModel): + return model.decoding.decode_hypothesis(hyps, fold_consecutive=False) + else: + return model.decoding.decode_hypothesis(hyps) -def decode_text_from_nbest_hypotheses(hyps, decoding): - return [decoding.decode_hypothesis(nbest_hyp.n_best_hypotheses) for nbest_hyp in hyps] +def decode_text_from_nbest_hypotheses(hyps, model): + if isinstance(model, EncDecCTCModel): + return [ + model.decoding.decode_hypothesis(nbest_hyp.n_best_hypotheses, fold_consecutive=False) for nbest_hyp in hyps + ] + else: + return [model.decoding.decode_hypothesis(nbest_hyp.n_best_hypotheses) for nbest_hyp in hyps] class TestRNNTDecoding: @@ -257,7 +311,7 @@ def test_rnnt_beam_decoding_return_best_hypothesis( hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0] check_res_best_hyps(num_samples, hyps) - hyps = decode_text_from_hypotheses(hyps, model.decoding) + hyps = decode_text_from_hypotheses(hyps, model) print_res_best_hyps(hyps) @pytest.mark.skipif( @@ -311,7 +365,7 @@ def test_rnnt_beam_decoding_return_nbest( batch_nbest_hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0] check_res_nbest_hyps(num_samples, batch_nbest_hyps) - batch_nbest_hyps = decode_text_from_nbest_hypotheses(batch_nbest_hyps, model.decoding) + batch_nbest_hyps = decode_text_from_nbest_hypotheses(batch_nbest_hyps, model) print_res_nbest_hyps(batch_nbest_hyps) @pytest.mark.skipif( @@ -381,7 +435,7 @@ def test_rnnt_beam_decoding_kenlm( hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0] check_res_best_hyps(num_samples, hyps) - hyps = decode_text_from_hypotheses(hyps, model.decoding) + hyps = decode_text_from_hypotheses(hyps, model) print_res_best_hyps(hyps) @@ -439,7 +493,7 @@ def test_tdt_beam_decoding_return_best_hypothesis( hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0] check_res_best_hyps(num_samples, hyps) - hyps = decode_text_from_hypotheses(hyps, model.decoding) + hyps = 
decode_text_from_hypotheses(hyps, model) print_res_best_hyps(hyps) @pytest.mark.skipif( @@ -496,7 +550,7 @@ def test_tdt_beam_decoding_return_nbest( batch_nbest_hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0] check_res_nbest_hyps(num_samples, batch_nbest_hyps) - batch_nbest_hyps = decode_text_from_nbest_hypotheses(batch_nbest_hyps, model.decoding) + batch_nbest_hyps = decode_text_from_nbest_hypotheses(batch_nbest_hyps, model) print_res_nbest_hyps(batch_nbest_hyps) @pytest.mark.skipif( @@ -577,7 +631,7 @@ def test_tdt_beam_decoding_kenlm( hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0] check_res_best_hyps(num_samples, hyps) - hyps = decode_text_from_hypotheses(hyps, model.decoding) + hyps = decode_text_from_hypotheses(hyps, model) print_res_best_hyps(hyps) @@ -685,3 +739,148 @@ def test_stated_stateless_bf16(self, test_audio_filenames, rnnt_model, tdt_model with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True): model.transcribe(test_audio_filenames, batch_size=batch_size, num_workers=None) + + +class TestCTCDecoding: + @pytest.mark.with_downloads + @pytest.mark.unit + @pytest.mark.parametrize( + "beam_config", + [ + {"allow_cuda_graphs": False}, + {"allow_cuda_graphs": True}, + ], + ) + @pytest.mark.parametrize("beam_size", [4]) + @pytest.mark.parametrize("batch_size", [4, 16]) + @pytest.mark.parametrize("device", DEVICES) + def test_ctc_beam_decoding_return_best_hypothesis( + self, test_audio_filenames, ctc_model, get_ctc_output, beam_config, device, batch_size, beam_size + ): + num_samples = min(batch_size, len(test_audio_filenames)) + model = ctc_model.to(device) + log_probs, encoded_lengths = get_ctc_output + log_probs, encoded_lengths = log_probs[:num_samples].to(device), encoded_lengths[:num_samples].to(device) + + vocab_size = model.tokenizer.vocab_size + decoding = BeamBatchedCTCInfer( + blank_index=vocab_size, + beam_size=beam_size, + return_best_hypothesis=True, + **beam_config, + ) + + print_unit_test_info( + strategy="beam_batch", + batch_size=batch_size, + beam_size=beam_size, + allow_cuda_graphs=beam_config.get('allow_cuda_graphs', True), + device=device, + ) + + with torch.no_grad(): + hyps = decoding(decoder_output=log_probs, decoder_lengths=encoded_lengths)[0] + + check_res_best_hyps(num_samples, hyps) + hyps = decode_text_from_hypotheses(hyps, model) + print_res_best_hyps(hyps) + + @pytest.mark.with_downloads + @pytest.mark.unit + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test is only GPU-based decoding") + @pytest.mark.parametrize( + "beam_config", + [ + {"allow_cuda_graphs": False}, + {"allow_cuda_graphs": True}, + ], + ) + @pytest.mark.parametrize("beam_size", [4]) + @pytest.mark.parametrize("batch_size", [4]) + def test_ctc_beam_decoding_return_nbest( + self, test_audio_filenames, ctc_model, get_ctc_output, beam_config, device, beam_size, batch_size + ): + device = torch.device("cuda") + num_samples = min(batch_size, len(test_audio_filenames)) + model = ctc_model.to(device) + log_probs, encoded_lengths = get_ctc_output + log_probs, encoded_lengths = log_probs[:num_samples].to(device), encoded_lengths[:num_samples].to(device) + + vocab_size = model.tokenizer.vocab_size + decoding = BeamBatchedCTCInfer( + blank_index=vocab_size, + beam_size=beam_size, + return_best_hypothesis=False, + **beam_config, + ) + + print_unit_test_info( + strategy="beam_batch", + batch_size=batch_size, + beam_size=beam_size, + allow_cuda_graphs=beam_config.get('allow_cuda_graphs', True), + 
            device=device,
+        )
+
+        with torch.no_grad():
+            batch_nbest_hyps = decoding(decoder_output=log_probs, decoder_lengths=encoded_lengths)[0]
+
+        check_res_nbest_hyps(num_samples, batch_nbest_hyps)
+        batch_nbest_hyps = decode_text_from_nbest_hypotheses(batch_nbest_hyps, model)
+        print_res_nbest_hyps(batch_nbest_hyps)
+
+    @pytest.mark.with_downloads
+    @pytest.mark.unit
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test is only GPU-based decoding")
+    @pytest.mark.parametrize(
+        "beam_config",
+        [
+            {"allow_cuda_graphs": False, "ngram_lm_alpha": 0.3, "beam_beta": 1.0},
+            {"allow_cuda_graphs": True, "ngram_lm_alpha": 0.3, "beam_beta": 1.0},
+        ],
+    )
+    @pytest.mark.parametrize("batch_size", [4])
+    @pytest.mark.parametrize("beam_size", [4])
+    def test_ctc_beam_decoding_kenlm(
+        self,
+        kenlm_model_path,
+        test_audio_filenames,
+        ctc_model,
+        get_ctc_output,
+        beam_config,
+        device,
+        batch_size,
+        beam_size,
+    ):
+        device = torch.device("cuda")
+        beam_config["ngram_lm_model"] = kenlm_model_path
+
+        num_samples = min(batch_size, len(test_audio_filenames))
+        model = ctc_model.to(device)
+        decoder_output, decoder_lengths = get_ctc_output
+        decoder_output, decoder_lengths = decoder_output[:num_samples].to(device), decoder_lengths[:num_samples].to(
+            device
+        )
+
+        vocab_size = model.tokenizer.vocab_size
+        decoding = BeamBatchedCTCInfer(
+            blank_index=vocab_size,
+            beam_size=beam_size,
+            return_best_hypothesis=True,
+            **beam_config,
+        )
+
+        print_unit_test_info(
+            strategy="beam_batch",
+            batch_size=batch_size,
+            beam_size=beam_size,
+            allow_cuda_graphs=beam_config.get('allow_cuda_graphs', True),
+            device=device,
+        )
+
+        with torch.no_grad():
+            hyps = decoding(decoder_output=decoder_output, decoder_lengths=decoder_lengths)[0]
+
+        check_res_best_hyps(num_samples, hyps)
+        hyps = decode_text_from_hypotheses(hyps, model)
+        print_res_best_hyps(hyps)
diff --git a/tests/collections/asr/decoding/test_batched_beam_hyps.py b/tests/collections/asr/decoding/test_batched_beam_hyps.py
index a69899d0c69d..d9a41dd924e8 100644
--- a/tests/collections/asr/decoding/test_batched_beam_hyps.py
+++ b/tests/collections/asr/decoding/test_batched_beam_hyps.py
@@ -18,7 +18,7 @@
 import pytest
 import torch
 
-from nemo.collections.asr.parts.utils.rnnt_batched_beam_utils import (
+from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import (
     INIT_POINTER_VALUE,
     NON_EXISTENT_LABEL_VALUE,
     BatchedBeamHyps,
@@ -271,31 +271,35 @@ def test_rnnt_add_with_invalid_results(self, device: torch.device):
     @pytest.mark.unit
     @pytest.mark.parametrize("device", DEVICES)
     def test_tdt_instantiate(self, device: torch.device):
-        _ = BatchedBeamHyps(batch_size=2, beam_size=3, init_length=4, device=device, blank_index=1024, is_tdt=True)
+        _ = BatchedBeamHyps(
+            batch_size=2, beam_size=3, init_length=4, device=device, blank_index=1024, model_type='tdt'
+        )
 
     @pytest.mark.unit
     @pytest.mark.parametrize("batch_size", [-1, 0])
     def test_tdt_instantiate_incorrect_batch_size(self, batch_size: Literal[-1] | Literal[0]):
         with pytest.raises(ValueError):
-            _ = BatchedBeamHyps(batch_size=batch_size, beam_size=4, init_length=3, blank_index=1024, is_tdt=True)
+            _ = BatchedBeamHyps(batch_size=batch_size, beam_size=4, init_length=3, blank_index=1024, model_type='tdt')
 
     @pytest.mark.unit
     @pytest.mark.parametrize("beam_size", [-1, 0])
     def test_tdt_instantiate_incorrect_beam_size(self, beam_size: Literal[-1] | Literal[0]):
         with pytest.raises(ValueError):
-            _ = BatchedBeamHyps(batch_size=2, beam_size=beam_size, init_length=3, blank_index=1024,
is_tdt=True) + _ = BatchedBeamHyps(batch_size=2, beam_size=beam_size, init_length=3, blank_index=1024, model_type='tdt') @pytest.mark.unit @pytest.mark.parametrize("init_length", [-1, 0]) def test_tdt_instantiate_incorrect_init_length(self, init_length: Literal[-1] | Literal[0]): with pytest.raises(ValueError): - _ = BatchedBeamHyps(batch_size=1, beam_size=4, init_length=init_length, blank_index=1024, is_tdt=True) + _ = BatchedBeamHyps(batch_size=1, beam_size=4, init_length=init_length, blank_index=1024, model_type='tdt') @pytest.mark.unit @pytest.mark.parametrize("device", DEVICES) def test_tdt_add_results(self, device: torch.device): # batch of size 2, add label for first utterance - hyps = BatchedBeamHyps(batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, is_tdt=True) + hyps = BatchedBeamHyps( + batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, model_type='tdt' + ) assert hyps._max_length == 1 hyps.add_results_( next_indices=torch.tensor([[0, 1, 2], [0, 1, 2]], device=device), @@ -321,7 +325,9 @@ def test_tdt_add_results(self, device: torch.device): @pytest.mark.unit @pytest.mark.parametrize("device", DEVICES) def test_tdt_add_multiple_results(self, device: torch.device): - hyps = BatchedBeamHyps(batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, is_tdt=True) + hyps = BatchedBeamHyps( + batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, model_type='tdt' + ) assert hyps._max_length == 1 hyps.add_results_( @@ -375,7 +381,9 @@ def test_tdt_add_multiple_results(self, device: torch.device): @pytest.mark.unit @pytest.mark.parametrize("device", DEVICES) def test_tdt_add_with_invalid_results(self, device: torch.device): - hyps = BatchedBeamHyps(batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, is_tdt=True) + hyps = BatchedBeamHyps( + batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, model_type='tdt' + ) assert hyps._max_length == 1 hyps.add_results_( @@ -425,6 +433,198 @@ def test_tdt_add_with_invalid_results(self, device: torch.device): [[2, 4, 5, 0], [3, 4, 4, 0], [4, 3, 6, 0]], ] + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_ctc_instantiate(self, device: torch.device): + _ = BatchedBeamHyps( + batch_size=2, beam_size=3, init_length=4, device=device, blank_index=1024, model_type='ctc' + ) + + @pytest.mark.unit + @pytest.mark.parametrize("batch_size", [-1, 0]) + def test_ctc_instantiate_incorrect_batch_size(self, batch_size: Literal[-1] | Literal[0]): + with pytest.raises(ValueError): + _ = BatchedBeamHyps(batch_size=batch_size, beam_size=4, init_length=3, blank_index=1024, model_type='ctc') + + @pytest.mark.unit + @pytest.mark.parametrize("beam_size", [-1, 0]) + def test_ctc_instantiate_incorrect_beam_size(self, beam_size: Literal[-1] | Literal[0]): + with pytest.raises(ValueError): + _ = BatchedBeamHyps(batch_size=2, beam_size=beam_size, init_length=3, blank_index=1024, model_type='ctc') + + @pytest.mark.unit + @pytest.mark.parametrize("init_length", [-1, 0]) + def test_ctc_instantiate_incorrect_init_length(self, init_length: Literal[-1] | Literal[0]): + with pytest.raises(ValueError): + _ = BatchedBeamHyps(batch_size=1, beam_size=4, init_length=init_length, blank_index=1024) + + @pytest.mark.unit + @pytest.mark.parametrize("y", [torch.tensor([1, 1024, 1024, 2, 2, 1024, 2, 3, 3, 1024, 3, 2, 2, 2])]) + def test_ctc_create_fold_consecutive_mask(self, y: torch.Tensor): + batched_hyps = BatchedBeamHyps(batch_size=1, 
beam_size=4, init_length=30, blank_index=1024, model_type='ctc')
+        mask = batched_hyps._create_fold_consecutive_mask(transcript=y)
+
+        assert y[mask].tolist() == [1, 2, 2, 3, 3, 2]
+
+    @pytest.mark.unit
+    @pytest.mark.parametrize("device", DEVICES)
+    def test_ctc_add_results(self, device: torch.device):
+        # batch of size 2, add label for first utterance
+        hyps = BatchedBeamHyps(
+            batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, model_type='ctc'
+        )
+        assert hyps._max_length == 1
+        hyps.add_results_(
+            next_indices=torch.tensor([[0, 1, 2], [0, 1, 2]], device=device),
+            next_labels=torch.tensor([[0, 1024, 1], [2, 1024, 1024]], device=device),
+            next_hyps_prob=torch.tensor([[0.5, 0.6, 0.8], [0.1, 0.2, 0.3]], device=device),
+        )
+        assert hyps._max_length == 2
+        assert hyps.current_lengths_nb.tolist() == [[1, 0, 1], [1, 0, 0]]
+        assert hyps.current_lengths_wb.tolist() == [[1, 1, 1], [1, 1, 1]]
+        assert_nested_lists_approx(actual=hyps.scores.tolist(), expected=[[0.5, 0.6, 0.8], [0.1, 0.2, 0.3]])
+        assert hyps.transcript_wb.tolist() == [
+            [[0, NON_EXISTENT_LABEL_VALUE], [1024, NON_EXISTENT_LABEL_VALUE], [1, NON_EXISTENT_LABEL_VALUE]],
+            [[2, NON_EXISTENT_LABEL_VALUE], [1024, NON_EXISTENT_LABEL_VALUE], [1024, NON_EXISTENT_LABEL_VALUE]],
+        ]
+        assert hyps.transcript_wb_prev_ptr.tolist() == [
+            [[0, INIT_POINTER_VALUE], [1, INIT_POINTER_VALUE], [2, INIT_POINTER_VALUE]],
+            [[0, INIT_POINTER_VALUE], [1, INIT_POINTER_VALUE], [2, INIT_POINTER_VALUE]],
+        ]
+        assert hyps.timestamps.tolist() == [
+            [[0, 1], [0, 1], [0, 1]],
+            [[0, 1], [0, 1], [0, 1]],
+        ]
+        assert hyps.last_label.tolist() == [
+            [0, 1024, 1],
+            [2, 1024, 1024],
+        ]
+
+    @pytest.mark.unit
+    @pytest.mark.parametrize("device", DEVICES)
+    def test_ctc_add_multiple_results(self, device: torch.device):
+        hyps = BatchedBeamHyps(
+            batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, model_type='ctc'
+        )
+        assert hyps._max_length == 1
+
+        hyps.add_results_(
+            next_indices=torch.tensor([[0, 1, 2], [0, 1, 2]], device=device),
+            next_labels=torch.tensor([[0, 1024, 1], [2, 1024, 1024]], device=device),
+            next_hyps_prob=torch.tensor([[0.5, 0.6, 0.8], [0.1, 0.2, 0.3]], device=device),
+        )
+
+        hyps.add_results_(
+            next_indices=torch.tensor([[0, 1, 1], [2, 1, 0]], device=device),
+            next_labels=torch.tensor([[3, 4, 1024], [5, 1024, 6]], device=device),
+            next_hyps_prob=torch.tensor([[0.3, 0.2, 0.1], [0.4, 0.5, 0.6]], device=device),
+        )
+
+        assert hyps._max_length == 4
+        assert hyps.current_lengths_nb.tolist() == [[2, 1, 0], [1, 0, 2]]
+        assert hyps.current_lengths_wb.tolist() == [[2, 2, 2], [2, 2, 2]]
+        assert_nested_lists_approx(actual=hyps.scores.tolist(), expected=[[0.3, 0.2, 0.1], [0.4, 0.5, 0.6]])
+        assert hyps.transcript_wb.tolist() == [
+            [
+                [0, 3, NON_EXISTENT_LABEL_VALUE, NON_EXISTENT_LABEL_VALUE],
+                [1024, 4, NON_EXISTENT_LABEL_VALUE, NON_EXISTENT_LABEL_VALUE],
+                [1, 1024, NON_EXISTENT_LABEL_VALUE, NON_EXISTENT_LABEL_VALUE],
+            ],
+            [
+                [2, 5, NON_EXISTENT_LABEL_VALUE, NON_EXISTENT_LABEL_VALUE],
+                [1024, 1024, NON_EXISTENT_LABEL_VALUE, NON_EXISTENT_LABEL_VALUE],
+                [1024, 6, NON_EXISTENT_LABEL_VALUE, NON_EXISTENT_LABEL_VALUE],
+            ],
+        ]
+        assert hyps.transcript_wb_prev_ptr.tolist() == [
+            [
+                [0, 0, INIT_POINTER_VALUE, INIT_POINTER_VALUE],
+                [1, 1, INIT_POINTER_VALUE, INIT_POINTER_VALUE],
+                [2, 1, INIT_POINTER_VALUE, INIT_POINTER_VALUE],
+            ],
+            [
+                [0, 2, INIT_POINTER_VALUE, INIT_POINTER_VALUE],
+                [1, 1, INIT_POINTER_VALUE, INIT_POINTER_VALUE],
+                [2, 0, INIT_POINTER_VALUE, INIT_POINTER_VALUE],
+            ],
+        ]
+        assert hyps.timestamps.tolist() == [
+            [
+                [0, 1, 2, 3],
+                [0, 1, 2, 3],
+                [0, 1, 2, 3],
+            ],
+            [
+                [0, 1, 2, 3],
+                [0, 1, 2, 3],
+                [0, 1, 2, 3],
+            ],
+        ]
+        assert hyps.last_label.tolist() == [[3, 4, 1024], [5, 1024, 6]]
+
+    @pytest.mark.unit
+    @pytest.mark.parametrize("device", DEVICES)
+    def test_ctc_add_with_invalid_results(self, device: torch.device):
+        hyps = BatchedBeamHyps(
+            batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, model_type='ctc'
+        )
+        assert hyps._max_length == 1
+
+        hyps.add_results_(
+            next_indices=torch.tensor([[0, 1, 2], [0, 1, 2]], device=device),
+            next_labels=torch.tensor([[0, 1024, 1], [2, 1024, 1024]], device=device),
+            next_hyps_prob=torch.tensor([[0.5, 0.6, 0.8], [0.1, 0.2, 0.3]], device=device),
+        )
+
+        hyps.add_results_(
+            next_indices=torch.tensor([[0, 1, 1], [2, 1, 0]], device=device),
+            next_labels=torch.tensor([[3, 4, 1024], [5, 1024, 6]], device=device),
+            next_hyps_prob=torch.tensor([[0.3, 0.2, 0.1], [0.4, 0.5, 0.6]], device=device),
+        )
+
+        hyps.add_results_(
+            next_indices=torch.tensor([[1, 0, 2], [2, 0, 1]], device=device),
+            next_labels=torch.tensor([[-1, 7, 8], [10, -1, 9]], device=device),
+            next_hyps_prob=torch.tensor([[0.35, 0.4, 0.1], [0.4, 0.55, 0.6]], device=device),
+        )
+
+        assert hyps._max_length == 4
+        assert hyps.current_lengths_nb.tolist() == [[1, 3, 1], [3, 1, 1]]
+        assert hyps.current_lengths_wb.tolist() == [[3, 3, 3], [3, 3, 3]]
+        assert_nested_lists_approx(actual=hyps.scores.tolist(), expected=[[0.35, 0.4, 0.1], [0.4, 0.55, 0.6]])
+        assert hyps.transcript_wb.tolist() == [
+            [
+                [0, 3, -1, NON_EXISTENT_LABEL_VALUE],
+                [1024, 4, 7, NON_EXISTENT_LABEL_VALUE],
+                [1, 1024, 8, NON_EXISTENT_LABEL_VALUE],
+            ],
+            [
+                [2, 5, 10, NON_EXISTENT_LABEL_VALUE],
+                [1024, 1024, -1, NON_EXISTENT_LABEL_VALUE],
+                [1024, 6, 9, NON_EXISTENT_LABEL_VALUE],
+            ],
+        ]
+        assert hyps.transcript_wb_prev_ptr.tolist() == [
+            [[0, 0, 1, INIT_POINTER_VALUE], [1, 1, 0, INIT_POINTER_VALUE], [2, 1, 2, INIT_POINTER_VALUE]],
+            [[0, 2, 2, INIT_POINTER_VALUE], [1, 1, 0, INIT_POINTER_VALUE], [2, 0, 1, INIT_POINTER_VALUE]],
+        ]
+        assert hyps.timestamps.tolist() == [
+            [
+                [0, 1, 2, 3],
+                [0, 1, 2, 3],
+                [0, 1, 2, 3],
+            ],
+            [
+                [0, 1, 2, 3],
+                [0, 1, 2, 3],
+                [0, 1, 2, 3],
+            ],
+        ]
+        assert hyps.last_label.tolist() == [
+            [4, 7, 8],
+            [10, 5, 9],
+        ]
+
 class TestConvertToHypotheses:
     @pytest.mark.unit
@@ -645,7 +845,9 @@ def test_rnnt_to_nbest_hyps_list(self, device: torch.device):
     @pytest.mark.unit
     @pytest.mark.parametrize("device", DEVICES)
     def test_tdt_flatten_sort(self, device: torch.device):
-        hyps = BatchedBeamHyps(batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, is_tdt=True)
+        hyps = BatchedBeamHyps(
+            batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, model_type='tdt'
+        )
 
         hyps.add_results_(
             next_indices=torch.tensor([[0, 1, 2], [0, 1, 2]], device=device),
@@ -698,7 +900,9 @@ def test_tdt_flatten_sort(self, device: torch.device):
     @pytest.mark.unit
     @pytest.mark.parametrize("device", DEVICES)
     def test_tdt_flatten_sort_norm(self, device: torch.device):
-        hyps = BatchedBeamHyps(batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, is_tdt=True)
+        hyps = BatchedBeamHyps(
+            batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, model_type='tdt'
+        )
 
         hyps.add_results_(
             next_indices=torch.tensor([[0, 1, 2], [0, 1, 2]], device=device),
@@ -751,7 +955,9 @@ def test_tdt_flatten_sort_norm(self, device: torch.device):
     @pytest.mark.unit
@pytest.mark.parametrize("device", DEVICES) def test_tdt_to_hyps_list(self, device: torch.device): - hyps = BatchedBeamHyps(batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, is_tdt=True) + hyps = BatchedBeamHyps( + batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, model_type='tdt' + ) hyps.add_results_( next_indices=torch.tensor([[0, 1, 2], [0, 1, 2]], device=device), @@ -794,7 +1000,9 @@ def test_tdt_to_hyps_list(self, device: torch.device): @pytest.mark.unit @pytest.mark.parametrize("device", DEVICES) def test_tdt_to_nbest_hyps_list(self, device: torch.device): - hyps = BatchedBeamHyps(batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, is_tdt=True) + hyps = BatchedBeamHyps( + batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, model_type='tdt' + ) hyps.add_results_( next_indices=torch.tensor([[0, 1, 2], [0, 1, 2]], device=device), @@ -847,3 +1055,163 @@ def test_tdt_to_nbest_hyps_list(self, device: torch.device): assert hypotheses[1].n_best_hypotheses[0].score == pytest.approx(0.6) assert hypotheses[1].n_best_hypotheses[1].score == pytest.approx(0.55) assert hypotheses[1].n_best_hypotheses[2].score == pytest.approx(0.4) + + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_ctc_flatten_sort(self, device: torch.device): + hyps = BatchedBeamHyps( + batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, model_type='ctc' + ) + + hyps.add_results_( + next_indices=torch.tensor([[0, 1, 2], [0, 1, 2]], device=device), + next_labels=torch.tensor([[3, 1024, 1], [2, 1024, 1024]], device=device), + next_hyps_prob=torch.tensor([[0.5, 0.6, 0.8], [0.1, 0.2, 0.3]], device=device), + ) + + hyps.add_results_( + next_indices=torch.tensor([[0, 1, 1], [2, 1, 0]], device=device), + next_labels=torch.tensor([[3, 4, 1024], [5, 1024, 6]], device=device), + next_hyps_prob=torch.tensor([[0.3, 0.2, 0.1], [0.4, 0.5, 0.6]], device=device), + ) + + hyps.add_results_( + next_indices=torch.tensor([[1, 0, 2], [2, 0, 1]], device=device), + next_labels=torch.tensor([[-1, 7, 8], [2, -1, 9]], device=device), + next_hyps_prob=torch.tensor([[0.35, 0.4, 0.1], [0.4, 0.55, 0.6]], device=device), + ) + hyps.flatten_sort_(score_norm=False) + + assert hyps.current_lengths_nb.tolist() == [[2, 1, 1], [1, 1, 3]] + assert hyps.current_lengths_wb.tolist() == [[3, 3, 3], [3, 3, 3]] + assert_nested_lists_approx(actual=hyps.scores.tolist(), expected=[[0.4, 0.35, 0.1], [0.6, 0.55, 0.4]]) + assert hyps.transcript_wb.tolist() == [ + [ + [3, 3, 7, NON_EXISTENT_LABEL_VALUE], + [1024, 4, -1, NON_EXISTENT_LABEL_VALUE], + [1024, 1024, 8, NON_EXISTENT_LABEL_VALUE], + ], + [ + [1024, 1024, 9, NON_EXISTENT_LABEL_VALUE], + [1024, 5, -1, NON_EXISTENT_LABEL_VALUE], + [2, 6, 2, NON_EXISTENT_LABEL_VALUE], + ], + ] + assert hyps.transcript_wb_prev_ptr.tolist() == [ + [[0, 0, 0, INIT_POINTER_VALUE], [1, 1, 1, INIT_POINTER_VALUE], [2, 2, 2, INIT_POINTER_VALUE]], + [[0, 0, 0, INIT_POINTER_VALUE], [1, 1, 1, INIT_POINTER_VALUE], [2, 2, 2, INIT_POINTER_VALUE]], + ] + assert hyps.timestamps.tolist() == [ + [ + [0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3], + ], + [ + [0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3], + ], + ] + assert hyps.last_label.tolist() == [ + [7, 4, 8], + [9, 5, 2], + ] + + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_ctc_to_hyps_list(self, device: torch.device): + hyps = BatchedBeamHyps( + batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, 
model_type='ctc' + ) + + hyps.add_results_( + next_indices=torch.tensor([[0, 1, 2], [0, 1, 2]], device=device), + next_labels=torch.tensor([[3, 1024, 1], [2, 1024, 1024]], device=device), + next_hyps_prob=torch.tensor([[0.5, 0.6, 0.8], [0.1, 0.2, 0.3]], device=device), + ) + + hyps.add_results_( + next_indices=torch.tensor([[0, 1, 1], [2, 1, 0]], device=device), + next_labels=torch.tensor([[3, 4, 1024], [5, 1024, 6]], device=device), + next_hyps_prob=torch.tensor([[0.3, 0.2, 0.1], [0.4, 0.5, 0.6]], device=device), + ) + + hyps.add_results_( + next_indices=torch.tensor([[1, 0, 2], [2, 0, 1]], device=device), + next_labels=torch.tensor([[-1, 7, 8], [2, -1, 9]], device=device), + next_hyps_prob=torch.tensor([[0.35, 0.4, 0.1], [0.4, 0.55, 0.6]], device=device), + ) + + hypotheses = hyps.to_hyps_list(score_norm=False) + + assert type(hypotheses) == list + assert type(hypotheses[0]) == Hypothesis + assert type(hypotheses[1]) == Hypothesis + + assert len(hypotheses) == 2 + + assert_hyps_sequence_equal(hypotheses[0].y_sequence, [3, 7]) + assert_hyps_sequence_equal(hypotheses[1].y_sequence, [9]) + + assert_hyps_timestamps_equal(hypotheses[0].timestamp, [0, 2]) + assert_hyps_timestamps_equal(hypotheses[1].timestamp, [2]) + + assert hypotheses[0].score == pytest.approx(0.4) + assert hypotheses[1].score == pytest.approx(0.6) + + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_ctc_to_nbest_hyps_list(self, device: torch.device): + hyps = BatchedBeamHyps( + batch_size=2, beam_size=3, init_length=1, device=device, blank_index=1024, model_type='ctc' + ) + + hyps.add_results_( + next_indices=torch.tensor([[0, 1, 2], [0, 1, 2]], device=device), + next_labels=torch.tensor([[3, 1024, 1], [2, 1024, 1024]], device=device), + next_hyps_prob=torch.tensor([[0.5, 0.6, 0.8], [0.1, 0.2, 0.3]], device=device), + ) + + hyps.add_results_( + next_indices=torch.tensor([[0, 1, 1], [2, 1, 0]], device=device), + next_labels=torch.tensor([[3, 4, 1024], [5, 1024, 6]], device=device), + next_hyps_prob=torch.tensor([[0.3, 0.2, 0.1], [0.4, 0.5, 0.6]], device=device), + ) + + hyps.add_results_( + next_indices=torch.tensor([[1, 0, 2], [2, 0, 1]], device=device), + next_labels=torch.tensor([[-1, 7, 8], [2, -1, 9]], device=device), + next_hyps_prob=torch.tensor([[0.35, 0.4, 0.1], [0.4, 0.55, 0.6]], device=device), + ) + + hypotheses = hyps.to_nbest_hyps_list(score_norm=False) + + assert type(hypotheses) == list + assert type(hypotheses[0]) == NBestHypotheses + assert type(hypotheses[1]) == NBestHypotheses + + assert len(hypotheses) == 2 + assert len(hypotheses[0].n_best_hypotheses) == 3 + assert len(hypotheses[1].n_best_hypotheses) == 3 + + assert_hyps_sequence_equal(hypotheses[0].n_best_hypotheses[0].y_sequence, [3, 7]) + assert_hyps_sequence_equal(hypotheses[0].n_best_hypotheses[1].y_sequence, [4]) + assert_hyps_sequence_equal(hypotheses[0].n_best_hypotheses[2].y_sequence, [8]) + assert_hyps_sequence_equal(hypotheses[1].n_best_hypotheses[0].y_sequence, [9]) + assert_hyps_sequence_equal(hypotheses[1].n_best_hypotheses[1].y_sequence, [5]) + assert_hyps_sequence_equal(hypotheses[1].n_best_hypotheses[2].y_sequence, [2, 6, 2]) + + assert_hyps_timestamps_equal(hypotheses[0].n_best_hypotheses[0].timestamp, [0, 2]) + assert_hyps_timestamps_equal(hypotheses[0].n_best_hypotheses[1].timestamp, [1]) + assert_hyps_timestamps_equal(hypotheses[0].n_best_hypotheses[2].timestamp, [2]) + assert_hyps_timestamps_equal(hypotheses[1].n_best_hypotheses[0].timestamp, [2]) + 
assert_hyps_timestamps_equal(hypotheses[1].n_best_hypotheses[1].timestamp, [1]) + assert_hyps_timestamps_equal(hypotheses[1].n_best_hypotheses[2].timestamp, [0, 1, 2]) + + assert hypotheses[0].n_best_hypotheses[0].score == pytest.approx(0.4) + assert hypotheses[0].n_best_hypotheses[1].score == pytest.approx(0.35) + assert hypotheses[0].n_best_hypotheses[2].score == pytest.approx(0.1) + assert hypotheses[1].n_best_hypotheses[0].score == pytest.approx(0.6) + assert hypotheses[1].n_best_hypotheses[1].score == pytest.approx(0.55) + assert hypotheses[1].n_best_hypotheses[2].score == pytest.approx(0.4)
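
The CTC expectations in the tests above follow from two ideas in this patch: CTC emits exactly one label per frame, so timestamps can be precomputed as plain frame indices, and the final token sequence is obtained by masking out invalid entries, blanks, and consecutive repeats. The standalone sketch below (not part of the patch; the blank id, helper name `fold_consecutive_mask`, and example values are illustrative assumptions) shows how such a mask collapses a frame-level label sequence while the same mask selects the matching timestamps, mirroring what `_create_fold_consecutive_mask` and `_create_timestamps_tensor` do.

# Standalone illustration (assumptions only, not the NeMo implementation).
import torch

BLANK_ID = 1024          # assumed blank index, as in the tests above
NON_EXISTENT_LABEL = -1  # padding value for steps where a hypothesis was not extended


def fold_consecutive_mask(transcript: torch.Tensor, blank_id: int) -> torch.Tensor:
    """Boolean mask over a 1D frame-level transcript: keep valid, non-blank, non-repeated labels."""
    not_repeated = torch.cat(
        [torch.tensor([True], device=transcript.device), transcript[1:] != transcript[:-1]]
    )
    return (transcript >= 0) & not_repeated & (transcript != blank_id)


frame_labels = torch.tensor([1, BLANK_ID, BLANK_ID, 2, 2, BLANK_ID, 2, 3, 3, NON_EXISTENT_LABEL])
# One label per frame in CTC, so the timestamps are simply the frame indices.
frame_timestamps = torch.arange(frame_labels.shape[0])

mask = fold_consecutive_mask(frame_labels, BLANK_ID)
print(frame_labels[mask].tolist())      # [1, 2, 2, 3] -- repeats separated by a blank are kept
print(frame_timestamps[mask].tolist())  # [0, 3, 6, 7] -- frame index of each surviving label

Because each CTC label is tied to a frame, hypotheses that merge during recombination are alternative alignments of the same token sequence; accordingly, the patch recombines CTC hypotheses with `torch.max` over the merged scores, while RNN-T and TDT keep the `logsumexp` combination.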