@@ -67,29 +67,6 @@ model:
window_stride: ${model.preprocessor.window_stride}
subsampling_factor: ${model.encoder.subsampling_factor}

test_ds:
manifest_filepath: null
is_tarred: False
tarred_audio_filepaths: null
sample_rate: 16000
num_spks: ${model.max_num_of_spks}
session_len_sec: 90 # Maximum session length in seconds
soft_label_thres: 0.5
soft_targets: False
labels: null
batch_size: ${batch_size}
shuffle: False
seq_eval_mode: True
num_workers: ${num_workers}
validation_mode: True
# lhotse config
use_lhotse: False
use_bucketing: False
drop_last: False
pin_memory: True
window_stride: ${model.preprocessor.window_stride}
subsampling_factor: ${model.encoder.subsampling_factor}

preprocessor:
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
normalize: "per_feature"
1 change: 0 additions & 1 deletion examples/speechlm2/salm_train.py
@@ -17,7 +17,6 @@
from lightning.pytorch import Trainer
from omegaconf import OmegaConf

from nemo.collections.common.data.fallback import FallbackDataset
from nemo.collections.speechlm2 import SALM, DataModule, SALMDataset
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager
44 changes: 43 additions & 1 deletion nemo/collections/asr/data/audio_to_diar_label.py
@@ -27,7 +27,14 @@
EndtoEndDiarizationSpeechLabel,
)
from nemo.core.classes import Dataset
from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType, ProbsType
from nemo.core.neural_types import (
AudioSignal,
EncodedRepresentation,
LengthsType,
NeuralType,
ProbsType,
SpectrogramType,
)
from nemo.utils import logging


@@ -1058,6 +1065,7 @@ def __init__(
session_len_sec: float,
num_spks: int,
featurizer,
fb_featurizer,
window_stride: float,
min_subsegment_duration: float = 0.03,
global_rank: int = 0,
@@ -1073,6 +1081,13 @@ def __init__(
round_digits=round_digits,
)
self.featurizer = featurizer
self.fb_featurizer = fb_featurizer
# STFT and subsampling factor parameters
self.n_fft = self.fb_featurizer.n_fft
self.hop_length = self.fb_featurizer.hop_length
self.stft_pad_amount = self.fb_featurizer.stft_pad_amount
self.subsampling_factor = subsampling_factor
# Annotation and target length parameters
self.round_digits = round_digits
self.feat_per_sec = int(1 / window_stride)
self.diar_frame_length = round(subsampling_factor * window_stride, round_digits)
@@ -1086,10 +1101,30 @@ def __init__(
self.round_digits = 2
self.floor_decimal = 10**self.round_digits
self.device = device
self.global_rank = global_rank

def __len__(self):
return len(self.collection)

def get_frame_count_from_time_series_length(self, seq_len):
"""
Convert a time-series (audio sample) length into the corresponding number of feature frames.
This is required to match the target length with the feature frame length used by ASR (STT) models.
The STFT length computation is adapted from
NeMo/nemo/collections/asr/parts/preprocessing/features.py::FilterbankFeatures::get_seq_len.

Args:
seq_len (int):
The sequence length of the time-series data (number of audio samples).

Returns:
frame_count (int):
The number of feature frames after subsampling.
"""
pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length).to(dtype=torch.long)
frame_count = int(np.ceil(seq_len / self.subsampling_factor))
return frame_count
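As a rough, self-contained illustration of the frame-count arithmetic above (the concrete values are assumptions for the example, not taken from this PR): with 16 kHz audio, n_fft=512, hop_length=160, no explicit stft_pad_amount, and subsampling_factor=8, one second of audio maps to 13 diarization frames.

import numpy as np
import torch

# Assumed preprocessor settings, for illustration only
n_fft, hop_length, stft_pad_amount, subsampling_factor = 512, 160, None, 8

seq_len = torch.tensor(16000)  # 1 s of 16 kHz audio
pad_amount = stft_pad_amount * 2 if stft_pad_amount is not None else n_fft // 2 * 2
stft_frames = torch.floor_divide(seq_len + pad_amount - n_fft, hop_length)  # -> 100
frame_count = int(np.ceil(stft_frames / subsampling_factor))                # -> 13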

def get_uniq_id_with_range(self, sample, deci=3):
"""
Generate unique training sample ID from unique file ID, offset and duration. The start-end time added
@@ -1238,10 +1273,15 @@ def __getitem__(self, index):
)
audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)]
audio_signal_length = torch.tensor(audio_signal.shape[0]).long()

# The target length must follow the ASR feature extraction convention, so clamp it using self.get_frame_count_from_time_series_length.
target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate)
target_len = torch.clamp(target_len, max=self.get_frame_count_from_time_series_length(audio_signal.shape[0]))

targets = self.parse_rttm_for_targets_and_lens(
rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len
)
targets = targets[:target_len, :]
return audio_signal, audio_signal_length, targets, target_len
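A small illustrative check of the clamping above (numbers assumed, not from the PR): if the timestamp-derived target length comes out one frame longer than the featurizer frame count, the clamp and the final slice keep the targets aligned with the acoustic features.

import torch

frame_count = 100                                       # assumed result of get_frame_count_from_time_series_length
target_len = torch.tensor(101)                          # timestamp-derived length, one frame too long
target_len = torch.clamp(target_len, max=frame_count)   # -> tensor(100)
targets = torch.zeros(101, 4)                           # e.g. 4 speakers
targets = targets[:target_len, :]                       # -> shape (100, 4)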


@@ -1357,6 +1397,7 @@ def __init__(
session_len_sec: float,
num_spks: int,
featurizer,
fb_featurizer,
window_stride,
global_rank: int,
soft_targets: bool,
Expand All @@ -1368,6 +1409,7 @@ def __init__(
session_len_sec=session_len_sec,
num_spks=num_spks,
featurizer=featurizer,
fb_featurizer=fb_featurizer,
window_stride=window_stride,
global_rank=global_rank,
soft_targets=soft_targets,
14 changes: 13 additions & 1 deletion nemo/collections/asr/models/sortformer_diar_models.py
@@ -34,7 +34,7 @@
from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy
from nemo.collections.asr.models.asr_model import ExportableEncDecModel
from nemo.collections.asr.parts.mixins.diarization import DiarizeConfig, SpkDiarizationMixin
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures, WaveformFeaturizer
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_ats_targets, get_pil_targets
from nemo.collections.asr.parts.utils.speaker_utils import generate_diarization_output_lines
@@ -203,6 +203,17 @@ def __setup_dataloader_from_config(self, config):
featurizer = WaveformFeaturizer(
sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=self.augmentor
)
fb_featurizer = FilterbankFeatures(
sample_rate=self._cfg.preprocessor.sample_rate,
normalize=self._cfg.preprocessor.normalize,
n_window_size=int(self._cfg.preprocessor.window_size * config['sample_rate']),
n_window_stride=int(self._cfg.preprocessor.window_stride * config['sample_rate']),
window=self._cfg.preprocessor.window,
nfilt=self._cfg.preprocessor.features,
n_fft=self._cfg.preprocessor.n_fft,
frame_splicing=self._cfg.preprocessor.frame_splicing,
dither=self._cfg.preprocessor.dither,
)

if 'manifest_filepath' in config and config['manifest_filepath'] is None:
logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
@@ -221,6 +232,7 @@ def __setup_dataloader_from_config(self, config):
session_len_sec=config.session_len_sec,
num_spks=config.num_spks,
featurizer=featurizer,
fb_featurizer=fb_featurizer,
window_stride=self._cfg.preprocessor.window_stride,
global_rank=global_rank,
soft_targets=config.soft_targets if 'soft_targets' in config else False,
3 changes: 1 addition & 2 deletions nemo/collections/asr/modules/audio_preprocessing.py
@@ -242,8 +242,6 @@ def __init__(
stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
stft_conv=False, # Deprecated arguments; kept for config compatibility
):
super().__init__(n_window_size, n_window_stride)

self._sample_rate = sample_rate
if window_size and n_window_size:
raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
@@ -255,6 +253,7 @@ def __init__(
n_window_size = int(window_size * self._sample_rate)
if window_stride:
n_window_stride = int(window_stride * self._sample_rate)
super().__init__(n_window_size, n_window_stride)
Author comment: fixes n_window_size and n_window_stride being None
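A minimal sketch of the ordering issue this comment refers to (class and argument names are simplified stand-ins, not the actual NeMo signatures): when only the second-based window_size / window_stride are given, the sample-based n_window_size / n_window_stride start out as None and are only derived inside __init__, so the parent constructor has to run after that derivation.

class _Base:
    def __init__(self, n_window_size, n_window_stride):
        # parent-style constructor: expects concrete sample counts, not None
        assert n_window_size is not None and n_window_stride is not None

class _MelPreprocessor(_Base):
    def __init__(self, sample_rate=16000, window_size=0.025, window_stride=0.01,
                 n_window_size=None, n_window_stride=None):
        if window_size:
            n_window_size = int(window_size * sample_rate)      # 400 samples
        if window_stride:
            n_window_stride = int(window_stride * sample_rate)  # 160 samples
        # calling super().__init__() before these conversions would pass None values
        super().__init__(n_window_size, n_window_stride)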


# Given the long and similar argument list, point to the class and instantiate it by reference
if not use_torchaudio:
10 changes: 8 additions & 2 deletions nemo/collections/asr/parts/preprocessing/features.py
@@ -87,6 +87,7 @@ def normalize_batch(x, seq_len, normalize_type):
torch.sum(torch.where(valid_mask.unsqueeze(1), x - x_mean.unsqueeze(2), 0.0) ** 2, axis=2)
/ (x_mean_denominator.unsqueeze(1) - 1.0)
)
x_std = x_std.masked_fill(x_std.isnan(), 0.0) # edge case: only 1 frame in denominator
# make sure x_std is not zero
x_std += CONSTANT
return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
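A tiny sketch of the edge case the masked_fill above guards against (shapes assumed for illustration): with a single valid frame, the per-feature variance denominator becomes zero, 0/0 produces NaN, and the fill resets it to 0 before the CONSTANT offset is added.

import torch

x = torch.randn(1, 80, 1)                         # batch of 1, 80 features, a single frame
x_mean = x.mean(dim=2)                            # mean over one frame equals the frame itself
x_std = ((x - x_mean.unsqueeze(2)) ** 2).sum(dim=2) / (1.0 - 1.0)  # 0 / 0 -> nan
x_std = x_std.masked_fill(x_std.isnan(), 0.0)     # handled as in the change above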
@@ -304,6 +305,7 @@ def __init__(
)
logging.info(f"PADDING: {pad_to}")

self.sample_rate = sample_rate
self.win_length = n_window_size
self.hop_length = n_window_stride
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
@@ -389,6 +391,7 @@ def stft(self, x):
center=False if self.exact_pad else True,
window=self.window.to(dtype=torch.float, device=x.device),
return_complex=True,
pad_mode="constant",
)

def log_zero_guard_value_fn(self, x):
@@ -409,21 +412,22 @@ def log_zero_guard_value_fn(self, x):
def get_seq_len(self, seq_len):
# Assuming that center is True if stft_pad_amount = 0
pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) + 1
seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length)
return seq_len.to(dtype=torch.long)

@property
def filter_banks(self):
return self.fb

def forward(self, x, seq_len, linear_spec=False):
seq_len_time = seq_len
seq_len_unfixed = self.get_seq_len(seq_len)
# fix for seq_len = 0 for streaming; if size was 0, it is always padded to 1, and normalizer fails
seq_len = torch.where(seq_len == 0, torch.zeros_like(seq_len_unfixed), seq_len_unfixed)

if self.stft_pad_amount is not None:
x = torch.nn.functional.pad(
x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "constant"
).squeeze(1)

# dither (only in training mode for eval determinism)
@@ -432,7 +436,9 @@ def forward(self, x, seq_len, linear_spec=False):

# do preemphasis
if self.preemph is not None:
timemask = torch.arange(x.shape[1], device=x.device).unsqueeze(0) < seq_len_time.unsqueeze(1)
x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
x = x.masked_fill(~timemask, 0.0)

# disable autocast to get full range of stft values
with torch.amp.autocast(x.device.type, enabled=False):