diff --git a/nemo/collections/asr/parts/utils/streaming_utils.py b/nemo/collections/asr/parts/utils/streaming_utils.py index d657c56a67b6..55824ff732f4 100644 --- a/nemo/collections/asr/parts/utils/streaming_utils.py +++ b/nemo/collections/asr/parts/utils/streaming_utils.py @@ -504,13 +504,14 @@ def update_feature_buffer(self, chunk): class AudioFeatureIterator(IterableDataset): - def __init__(self, samples, frame_len, preprocessor, device, pad_to_frame_len=True): + def __init__(self, samples, frame_len, preprocessor, device, pad_to_frame_len=True, min_frame_len=1): self._samples = samples self._frame_len = frame_len self._start = 0 self.output = True self.count = 0 self.pad_to_frame_len = pad_to_frame_len + self.min_frame_len = min_frame_len timestep_duration = preprocessor._cfg['window_stride'] self._feature_frame_len = frame_len / timestep_duration audio_signal = torch.from_numpy(self._samples).unsqueeze_(0).to(device) @@ -532,13 +533,16 @@ def __next__(self): frame = self._features[:, self._start : last].cpu() self._start = last else: - if not self.pad_to_frame_len: - frame = self._features[:, self._start : self._features_len[0]].cpu() + self.output = False + segment = self._features[:, self._start : self._features_len[0]].cpu() + if segment.shape[1] == 0: + raise StopIteration + if not self.pad_to_frame_len and segment.shape[1] >= self.min_frame_len: + frame = segment else: - frame = np.zeros([self._features.shape[0], int(self._feature_frame_len)], dtype='float32') - segment = self._features[:, self._start : self._features_len[0]].cpu() + target_frame_len = int(self._feature_frame_len) if self.pad_to_frame_len else self.min_frame_len + frame = np.zeros([self._features.shape[0], target_frame_len], dtype='float32') frame[:, : segment.shape[1]] = segment - self.output = False self.count += 1 return frame @@ -1811,6 +1815,8 @@ def __init__(self, asr_model, frame_len=4, total_buffer=4, batch_size=4): super().__init__(asr_model, frame_len, total_buffer, batch_size, 
pad_to_buffer_len=False) self.window_stride = asr_model._cfg.preprocessor.window_stride self.subsampling_factor = asr_model._cfg.encoder.subsampling_factor + self.min_input_frames = None + self.timestamps_min_input_frames = None self.chunk_offsets = [ 0, ] # chunk offsets in terms of num frames before subsampling @@ -1883,9 +1889,17 @@ def read_audio_file(self, audio_filepath: str, delay, model_stride_in_secs, meta self.input_tokens = self.get_input_tokens(meta_data) samples = get_samples(audio_filepath) padded_samples = np.pad(samples, (0, int(delay * model_stride_in_secs * self.asr_model._cfg.sample_rate))) + min_frame_len = self._get_min_input_frames() + if timestamps and self.timestamps_asr_model is not None: + min_frame_len = max(min_frame_len, self._get_timestamps_min_input_frames()) frame_reader = AudioFeatureIterator( - padded_samples, self.frame_len, self.raw_preprocessor, self.asr_model.device, pad_to_frame_len=False + padded_samples, + self.frame_len, + self.raw_preprocessor, + self.asr_model.device, + pad_to_frame_len=False, + min_frame_len=min_frame_len, ) self.set_frame_reader(frame_reader) if timestamps and self.timestamps_asr_model is not None: @@ -1901,6 +1915,7 @@ def read_audio_file(self, audio_filepath: str, delay, model_stride_in_secs, meta self.frame_len, self.timestamps_frame_asr.raw_preprocessor, self.timestamps_frame_asr.asr_model.device, + min_frame_len=min_frame_len, ) self.timestamps_frame_asr.set_frame_reader(ts_model_frame_reader) @@ -1909,6 +1924,19 @@ def _get_batch_preds(self, keep_logits=False, timestamps=False): device = self.asr_model.device for batch in iter(self.data_loader): feat_signal, feat_signal_len = batch + min_input_frames = self._get_min_input_frames() + short_chunk_mask = feat_signal_len < min_input_frames + if short_chunk_mask.any(): + short_chunk_lengths = feat_signal_len[short_chunk_mask].tolist() + logging.warning( + "Zero-padding %d chunk(s) shorter than the minimum encoder input length (%d frames): %s", + 
def _get_min_input_frames(self):
    """Return the minimum encoder input length, in feature frames, for ``self.asr_model``.

    The value is probed once with ``_estimate_min_input_frames`` and cached on
    ``self.min_input_frames``. Assigning that attribute beforehand skips the probe.
    """
    if self.min_input_frames is None:
        self.min_input_frames = self._estimate_min_input_frames(self.asr_model)
    return self.min_input_frames


def _get_timestamps_min_input_frames(self):
    """Return the minimum encoder input length, in feature frames, for ``self.timestamps_asr_model``.

    The value is probed once with ``_estimate_min_input_frames`` and cached on
    ``self.timestamps_min_input_frames``. This method does not check that
    ``self.timestamps_asr_model`` is set; the caller is expected to do so.
    """
    if self.timestamps_min_input_frames is None:
        self.timestamps_min_input_frames = self._estimate_min_input_frames(self.timestamps_asr_model)
    return self.timestamps_min_input_frames


@staticmethod
def _estimate_min_input_frames(asr_model, max_candidate_frames=64):
    """Probe ``asr_model.encoder`` for the shortest input length it accepts.

    All-zero feature tensors of shape ``(1, features, n)`` are fed to the encoder for
    ``n = 1, 2, ..., max_candidate_frames``; the first ``n`` whose forward pass does
    not raise is returned.

    Args:
        asr_model: object exposing ``encoder`` (a ``torch.nn.Module``), ``device`` and
            ``_cfg.preprocessor.features``.
        max_candidate_frames: largest input length to try.

    Returns:
        The smallest accepted number of frames, or ``max_candidate_frames`` (after
        logging a warning) when every candidate was rejected.

    Raises:
        RuntimeError: re-raised unchanged when the encoder fails with a message other
            than the "kernel size" one treated here as "input too short".
    """
    too_short_marker = "Kernel size can't be greater than actual input size"

    encoder = asr_model.encoder
    feat_in = asr_model._cfg.preprocessor.features
    # A parameterless encoder must not leak a bare StopIteration out of next();
    # fall back to the default floating dtype in that case.
    first_param = next(encoder.parameters(), None)
    encoder_dtype = first_param.dtype if first_param is not None else torch.get_default_dtype()
    encoder_device = asr_model.device
    encoder_was_training = encoder.training

    try:
        encoder.eval()
        # The probe outputs are discarded, so do not record autograd graphs for them.
        # no_grad (rather than inference_mode) is used so that any tensor the encoder
        # might keep from a forward pass stays an ordinary tensor.
        with torch.no_grad():
            for candidate_frames in range(1, max_candidate_frames + 1):
                test_signal = torch.zeros(1, feat_in, candidate_frames, device=encoder_device, dtype=encoder_dtype)
                test_signal_len = torch.tensor([candidate_frames], device=encoder_device)
                try:
                    encoder(audio_signal=test_signal, length=test_signal_len)
                except RuntimeError as error:
                    if too_short_marker not in str(error):
                        raise
                else:
                    return candidate_frames
    finally:
        # Restore the caller's train/eval mode on every exit path.
        encoder.train(encoder_was_training)

    # Nothing was accepted: the fallback is an upper bound that was never validated.
    logging.warning(
        "No encoder input length up to %d frames was accepted; falling back to %d frames.",
        max_candidate_frames,
        max_candidate_frames,
    )
    return max_candidate_frames
@pytest.mark.unit
def test_FrameBatchMultiTaskAED_zero_pads_too_short_tail_chunk(self, monkeypatch):
    """A tail chunk shorter than the pinned minimum is reported at the minimum length.

    The batch holds one 4-frame chunk and one 2-frame chunk while
    ``min_input_frames`` is pinned to 3. The test expects ``predict_step`` to be
    called with lengths ``[4, 3]``, the joined text to be ``"full tail"`` and
    ``chunk_offsets`` to end up as ``[0, 4, 3]``.
    """

    class StubPreprocessor:
        def __init__(self):
            self._cfg = {'window_stride': 0.01}

        def to(self, device):
            return self

    class StubPrompt:
        PROMPT_LANGUAGE_SLOT = "prompt_language"

        def encode_dialog(self, turns):
            return {"context_ids": [1, 2, 3]}

    class StubTokenizer:
        vocabulary = ["a", "b"]

    class PassThroughEncoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # A single parameter so the encoder has a dtype to report.
            self.probe = torch.nn.Parameter(torch.zeros(1))

        def forward(self, audio_signal, length):
            return audio_signal, length

    class StubModel:
        def __init__(self):
            model_cfg = {
                'sample_rate': 16000,
                'preprocessor': {'window_stride': 0.01, 'features': 4},
                'encoder': {'subsampling_factor': 2},
            }
            self._cfg = DictConfig(model_cfg)
            self.encoder = PassThroughEncoder()
            self.preprocessor = SimpleNamespace(log=False)
            self.prompt = StubPrompt()
            self.prompt_format = "canary"
            self.tokenizer = StubTokenizer()
            self.decoder = None
            self.device = torch.device("cpu")
            self.timestamps_asr_model = None

        def predict_step(self, batch_input, has_processed_signal=True, timestamps=False):
            raise NotImplementedError

    # The raw preprocessor is built through ASRModel.from_config_dict; swap in the stub.
    monkeypatch.setattr(
        "nemo.collections.asr.parts.utils.streaming_utils.ASRModel.from_config_dict",
        lambda cfg: StubPreprocessor(),
    )

    asr_model = StubModel()
    model = FrameBatchMultiTaskAED(asr_model, batch_size=2)
    # Pin the minimum so no encoder probing is needed.
    model.min_input_frames = 3

    meta_data = {
        'audio_filepath': 'unused.wav',
        'duration': 100000,
        'source_lang': 'en',
        'taskname': 'asr',
        'target_lang': 'en',
        'pnc': 'yes',
        'answer': 'nothing',
    }
    model.input_tokens = model.get_input_tokens(meta_data)

    num_features = asr_model._cfg.preprocessor.features
    valid_chunk = np.zeros((num_features, 4), dtype=np.float32)
    short_tail_chunk = np.zeros((num_features, 2), dtype=np.float32)
    # One real batch followed by an empty one.
    chunk_batches = [[valid_chunk, short_tail_chunk], []]
    predict_calls = []

    def fake_get_buffers_batch():
        return chunk_batches.pop(0)

    def fake_predict_step(batch_input, has_processed_signal=True, timestamps=False):
        predict_calls.append(batch_input.audio_lens.tolist())
        full_hyp = Hypothesis(score=0.0, y_sequence=torch.tensor([]), text="full")
        tail_hyp = Hypothesis(score=0.0, y_sequence=torch.tensor([]), text="tail")
        return [full_hyp, tail_hyp]

    monkeypatch.setattr(model.frame_bufferer, "get_buffers_batch", fake_get_buffers_batch)
    monkeypatch.setattr(asr_model, "predict_step", fake_predict_step)

    outputs = model.transcribe()

    assert predict_calls == [[4, 3]]
    assert outputs.text == "full tail"
    assert model.chunk_offsets == [0, 4, 3]