Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 69 additions & 7 deletions nemo/collections/asr/parts/utils/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,13 +504,14 @@


class AudioFeatureIterator(IterableDataset):
def __init__(self, samples, frame_len, preprocessor, device, pad_to_frame_len=True):
def __init__(self, samples, frame_len, preprocessor, device, pad_to_frame_len=True, min_frame_len=1):
self._samples = samples
self._frame_len = frame_len
self._start = 0
self.output = True
self.count = 0
self.pad_to_frame_len = pad_to_frame_len
self.min_frame_len = min_frame_len
timestep_duration = preprocessor._cfg['window_stride']
self._feature_frame_len = frame_len / timestep_duration
audio_signal = torch.from_numpy(self._samples).unsqueeze_(0).to(device)
Expand All @@ -532,13 +533,16 @@
frame = self._features[:, self._start : last].cpu()
self._start = last
else:
if not self.pad_to_frame_len:
frame = self._features[:, self._start : self._features_len[0]].cpu()
self.output = False
segment = self._features[:, self._start : self._features_len[0]].cpu()
if segment.shape[1] == 0:
raise StopIteration
if not self.pad_to_frame_len and segment.shape[1] >= self.min_frame_len:
frame = segment
else:
frame = np.zeros([self._features.shape[0], int(self._feature_frame_len)], dtype='float32')
segment = self._features[:, self._start : self._features_len[0]].cpu()
target_frame_len = int(self._feature_frame_len) if self.pad_to_frame_len else self.min_frame_len
frame = np.zeros([self._features.shape[0], target_frame_len], dtype='float32')
frame[:, : segment.shape[1]] = segment
self.output = False
self.count += 1
return frame

Expand Down Expand Up @@ -853,7 +857,7 @@
for log_prob in log_probs:
self.all_logits.append(log_prob.cpu())
else:
del log_probs

Check notice

Code scanning / CodeQL

Mismatch between signature and use of an overridden method Note

Overridden method signature does not match
call
, where it is passed too few arguments. Overriding method
method FrameBatchMultiTaskAED.transcribe
matches the call.
Overridden method signature does not match
call
, where it is passed too few arguments. Overriding method
method FrameBatchMultiTaskAED.transcribe
matches the call.
Overridden method signature does not match
call
, where it is passed an argument named 'timestamps'. Overriding method
method FrameBatchMultiTaskAED.transcribe
matches the call.
del encoded_len
del predictions

Expand Down Expand Up @@ -1811,6 +1815,8 @@
super().__init__(asr_model, frame_len, total_buffer, batch_size, pad_to_buffer_len=False)
self.window_stride = asr_model._cfg.preprocessor.window_stride
self.subsampling_factor = asr_model._cfg.encoder.subsampling_factor
self.min_input_frames = None
self.timestamps_min_input_frames = None
self.chunk_offsets = [
0,
] # chunk offsets in terms of num frames before subsampling
Expand Down Expand Up @@ -1883,9 +1889,17 @@
self.input_tokens = self.get_input_tokens(meta_data)
samples = get_samples(audio_filepath)
padded_samples = np.pad(samples, (0, int(delay * model_stride_in_secs * self.asr_model._cfg.sample_rate)))
min_frame_len = self._get_min_input_frames()
if timestamps and self.timestamps_asr_model is not None:
min_frame_len = max(min_frame_len, self._get_timestamps_min_input_frames())

frame_reader = AudioFeatureIterator(
padded_samples, self.frame_len, self.raw_preprocessor, self.asr_model.device, pad_to_frame_len=False
padded_samples,
self.frame_len,
self.raw_preprocessor,
self.asr_model.device,
pad_to_frame_len=False,
min_frame_len=min_frame_len,
)
self.set_frame_reader(frame_reader)
if timestamps and self.timestamps_asr_model is not None:
Expand All @@ -1901,6 +1915,7 @@
self.frame_len,
self.timestamps_frame_asr.raw_preprocessor,
self.timestamps_frame_asr.asr_model.device,
min_frame_len=min_frame_len,
)
self.timestamps_frame_asr.set_frame_reader(ts_model_frame_reader)

Expand All @@ -1909,6 +1924,19 @@
device = self.asr_model.device
for batch in iter(self.data_loader):
feat_signal, feat_signal_len = batch
min_input_frames = self._get_min_input_frames()
short_chunk_mask = feat_signal_len < min_input_frames
if short_chunk_mask.any():
short_chunk_lengths = feat_signal_len[short_chunk_mask].tolist()
logging.warning(
"Zero-padding %d chunk(s) shorter than the minimum encoder input length (%d frames): %s",
len(short_chunk_lengths),
min_input_frames,
short_chunk_lengths,
)
if feat_signal.size(-1) < min_input_frames:
feat_signal = torch.nn.functional.pad(feat_signal, (0, min_input_frames - feat_signal.size(-1)))
feat_signal_len = feat_signal_len.clamp(min=min_input_frames)
# keep track of chunk offsets
self.chunk_offsets.extend(feat_signal_len.tolist())
feat_signal, feat_signal_len = feat_signal.to(device), feat_signal_len.to(device)
Expand All @@ -1930,6 +1958,40 @@
self.all_preds.extend(predictions)
del predictions

def _get_min_input_frames(self):
if self.min_input_frames is None:
self.min_input_frames = self._estimate_min_input_frames(self.asr_model)
return self.min_input_frames

def _get_timestamps_min_input_frames(self):
if self.timestamps_min_input_frames is None:
self.timestamps_min_input_frames = self._estimate_min_input_frames(self.timestamps_asr_model)
return self.timestamps_min_input_frames

@staticmethod
def _estimate_min_input_frames(asr_model, max_candidate_frames=64):
encoder = asr_model.encoder
feat_in = asr_model._cfg.preprocessor.features
encoder_dtype = next(encoder.parameters()).dtype
encoder_device = asr_model.device
encoder_was_training = encoder.training

try:
encoder.eval()
for candidate_frames in range(1, max_candidate_frames + 1):
test_signal = torch.zeros(1, feat_in, candidate_frames, device=encoder_device, dtype=encoder_dtype)
test_signal_len = torch.tensor([candidate_frames], device=encoder_device)
try:
encoder(audio_signal=test_signal, length=test_signal_len)
return candidate_frames
except RuntimeError as error:
if "Kernel size can't be greater than actual input size" not in str(error):
raise
finally:
encoder.train(encoder_was_training)

return max_candidate_frames

def transcribe(
self,
tokens_per_chunk: Optional[int] = None,
Expand Down
93 changes: 93 additions & 0 deletions tests/collections/asr/test_asr_multitask_model_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

import os
import tempfile
from types import SimpleNamespace

import numpy as np
import pytest
import torch
from lhotse import CutSet, MonoCut, SupervisionSegment
Expand Down Expand Up @@ -568,6 +570,97 @@ def test_FrameBatchMultiTaskAED(self, asr_model, test_data_dir):
outputs = model.transcribe()
assert isinstance(outputs, Hypothesis)

    @pytest.mark.unit
    def test_FrameBatchMultiTaskAED_zero_pads_too_short_tail_chunk(self, monkeypatch):
        """Verify FrameBatchMultiTaskAED zero-pads a tail chunk shorter than the
        minimum encoder input length (``min_input_frames``) instead of failing.

        Uses lightweight dummies for the ASR model and monkeypatches the buffer
        source and ``predict_step`` so no real audio or model is needed. With
        ``min_input_frames = 3``, a 2-frame tail chunk must be padded to length 3
        before prediction, while the 4-frame chunk passes through unchanged.
        """

        # Stand-in for the raw feature preprocessor built via ASRModel.from_config_dict.
        class DummyPreprocessor:
            def __init__(self):
                self._cfg = {'window_stride': 0.01}

            def to(self, device):
                # Device moves are a no-op for the dummy.
                return self

        # Minimal prompt object: get_input_tokens only needs encode_dialog.
        class DummyPrompt:
            PROMPT_LANGUAGE_SLOT = "prompt_language"

            def encode_dialog(self, turns):
                return {"context_ids": [1, 2, 3]}

        class DummyTokenizer:
            vocabulary = ["a", "b"]

        # Identity encoder with one parameter so dtype probing has something to find.
        class DummyEncoder(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.probe = torch.nn.Parameter(torch.zeros(1))

            def forward(self, audio_signal, length):
                return audio_signal, length

        # Bare-bones model exposing only the attributes FrameBatchMultiTaskAED reads.
        class DummyModel:
            def __init__(self):
                self.timestamps_asr_model = None
                self.device = torch.device("cpu")
                self.decoder = None
                self.tokenizer = DummyTokenizer()
                self.prompt_format = "canary"
                self.prompt = DummyPrompt()
                self.preprocessor = SimpleNamespace(log=False)
                self.encoder = DummyEncoder()
                self._cfg = DictConfig(
                    {
                        'sample_rate': 16000,
                        'preprocessor': {'window_stride': 0.01, 'features': 4},
                        'encoder': {'subsampling_factor': 2},
                    }
                )

            # Replaced below via monkeypatch; raising here catches un-patched calls.
            def predict_step(self, batch_input, has_processed_signal=True, timestamps=False):
                raise NotImplementedError

        # FrameBatchMultiTaskAED's __init__ builds a raw preprocessor from config;
        # intercept that to return the dummy instead of a real NeMo module.
        monkeypatch.setattr(
            "nemo.collections.asr.parts.utils.streaming_utils.ASRModel.from_config_dict",
            lambda cfg: DummyPreprocessor(),
        )

        asr_model = DummyModel()
        model = FrameBatchMultiTaskAED(asr_model, batch_size=2)
        # Pre-seed the cache so the minimum-length probe is not run against the dummy encoder.
        model.min_input_frames = 3
        model.input_tokens = model.get_input_tokens(
            {
                'audio_filepath': 'unused.wav',
                'duration': 100000,
                'source_lang': 'en',
                'taskname': 'asr',
                'target_lang': 'en',
                'pnc': 'yes',
                'answer': 'nothing',
            }
        )

        # One chunk at an acceptable length (4 frames) and one too-short tail chunk (2 frames);
        # feature dim matches preprocessor.features. An empty batch terminates the loop.
        valid_chunk = np.zeros((asr_model._cfg.preprocessor.features, 4), dtype=np.float32)
        short_tail_chunk = np.zeros((asr_model._cfg.preprocessor.features, 2), dtype=np.float32)
        chunk_batches = [[valid_chunk, short_tail_chunk], []]
        predict_calls = []

        def fake_get_buffers_batch():
            return chunk_batches.pop(0)

        # Record the per-chunk lengths actually handed to the model and return one
        # hypothesis per chunk so transcribe() can merge them.
        def fake_predict_step(batch_input, has_processed_signal=True, timestamps=False):
            predict_calls.append(batch_input.audio_lens.tolist())
            return [
                Hypothesis(score=0.0, y_sequence=torch.tensor([]), text="full"),
                Hypothesis(score=0.0, y_sequence=torch.tensor([]), text="tail"),
            ]

        monkeypatch.setattr(model.frame_bufferer, "get_buffers_batch", fake_get_buffers_batch)
        monkeypatch.setattr(asr_model, "predict_step", fake_predict_step)

        outputs = model.transcribe()

        # The 2-frame tail must have been padded up to min_input_frames (3); the
        # 4-frame chunk is untouched, and exactly one predict_step call was made.
        assert predict_calls == [[4, 3]]
        assert outputs.text == "full tail"
        # chunk_offsets tracks the (clamped) per-chunk frame counts after the initial 0.
        assert model.chunk_offsets == [0, 4, 3]

@pytest.mark.with_downloads()
@pytest.mark.unit
def test_FrameBatchMultiTaskAED_with_timestamps(self, canary_1b_flash):
Expand Down
Loading