diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml index 66cfc5fd1b61..2e85c5dc73b7 100644 --- a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml +++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml @@ -67,29 +67,6 @@ model: window_stride: ${model.preprocessor.window_stride} subsampling_factor: ${model.encoder.subsampling_factor} - test_ds: - manifest_filepath: null - is_tarred: False - tarred_audio_filepaths: null - sample_rate: 16000 - num_spks: ${model.max_num_of_spks} - session_len_sec: 90 # Maximum session length in seconds - soft_label_thres: 0.5 - soft_targets: False - labels: null - batch_size: ${batch_size} - shuffle: False - seq_eval_mode: True - num_workers: ${num_workers} - validation_mode: True - # lhotse config - use_lhotse: False - use_bucketing: False - drop_last: False - pin_memory: True - window_stride: ${model.preprocessor.window_stride} - subsampling_factor: ${model.encoder.subsampling_factor} - preprocessor: _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor normalize: "per_feature" diff --git a/examples/speechlm2/salm_train.py b/examples/speechlm2/salm_train.py index c10638056bdf..20ba5597bc9e 100644 --- a/examples/speechlm2/salm_train.py +++ b/examples/speechlm2/salm_train.py @@ -17,7 +17,6 @@ from lightning.pytorch import Trainer from omegaconf import OmegaConf -from nemo.collections.common.data.fallback import FallbackDataset from nemo.collections.speechlm2 import SALM, DataModule, SALMDataset from nemo.core.config import hydra_runner from nemo.utils.exp_manager import exp_manager diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index 817938b758ae..25022c0a5d91 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -27,7 +27,14 @@ EndtoEndDiarizationSpeechLabel, ) from nemo.core.classes import Dataset -from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType, ProbsType +from nemo.core.neural_types import ( + AudioSignal, + EncodedRepresentation, + LengthsType, + NeuralType, + ProbsType, + SpectrogramType, +) from nemo.utils import logging @@ -1058,6 +1065,7 @@ def __init__( session_len_sec: float, num_spks: int, featurizer, + fb_featurizer, window_stride: float, min_subsegment_duration: float = 0.03, global_rank: int = 0, @@ -1073,6 +1081,13 @@ def __init__( round_digits=round_digits, ) self.featurizer = featurizer + self.fb_featurizer = fb_featurizer + # STFT and subsampling factor parameters + self.n_fft = self.fb_featurizer.n_fft + self.hop_length = self.fb_featurizer.hop_length + self.stft_pad_amount = self.fb_featurizer.stft_pad_amount + self.subsampling_factor = subsampling_factor + # Annotation and target length parameters self.round_digits = round_digits self.feat_per_sec = int(1 / window_stride) self.diar_frame_length = round(subsampling_factor * window_stride, round_digits) @@ -1086,10 +1101,30 @@ def __init__( self.round_digits = 2 self.floor_decimal = 10**self.round_digits self.device = device + self.global_rank = global_rank def __len__(self): return len(self.collection) + def get_frame_count_from_time_series_length(self, seq_len): + """ + This function is used to get the sequence length of the audio signal. This is required to match + the feature frame length with ASR (STT) models. This function is copied from + NeMo/nemo/collections/asr/parts/preprocessing/features.py::FilterbankFeatures::get_seq_len. + + Args: + seq_len (int): + The sequence length of the time-series data. + + Returns: + seq_len (int): + The sequence length of the feature frames. + """ + pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2 + seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length).to(dtype=torch.long) + frame_count = int(np.ceil(seq_len / self.subsampling_factor)) + return frame_count + def get_uniq_id_with_range(self, sample, deci=3): """ Generate unique training sample ID from unique file ID, offset and duration. The start-end time added @@ -1238,10 +1273,15 @@ def __getitem__(self, index): ) audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)] audio_signal_length = torch.tensor(audio_signal.shape[0]).long() + + # Target length should be following the ASR feature extraction convention: Use self.get_frame_count_from_time_series_length. target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate) + target_len = torch.clamp(target_len, max=self.get_frame_count_from_time_series_length(audio_signal.shape[0])) + targets = self.parse_rttm_for_targets_and_lens( rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len ) + targets = targets[:target_len, :] return audio_signal, audio_signal_length, targets, target_len @@ -1357,6 +1397,7 @@ def __init__( session_len_sec: float, num_spks: int, featurizer, + fb_featurizer, window_stride, global_rank: int, soft_targets: bool, @@ -1368,6 +1409,7 @@ def __init__( session_len_sec=session_len_sec, num_spks=num_spks, featurizer=featurizer, + fb_featurizer=fb_featurizer, window_stride=window_stride, global_rank=global_rank, soft_targets=soft_targets, diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 5eef440c98ab..63b6db61bbda 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -34,7 +34,7 @@ from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy from nemo.collections.asr.models.asr_model import ExportableEncDecModel from nemo.collections.asr.parts.mixins.diarization import DiarizeConfig, SpkDiarizationMixin -from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures, WaveformFeaturizer from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_ats_targets, get_pil_targets from nemo.collections.asr.parts.utils.speaker_utils import generate_diarization_output_lines @@ -203,6 +203,17 @@ def __setup_dataloader_from_config(self, config): featurizer = WaveformFeaturizer( sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=self.augmentor ) + fb_featurizer = FilterbankFeatures( + sample_rate=self._cfg.preprocessor.sample_rate, + normalize=self._cfg.preprocessor.normalize, + n_window_size=int(self._cfg.preprocessor.window_size * config['sample_rate']), + n_window_stride=int(self._cfg.preprocessor.window_stride * config['sample_rate']), + window=self._cfg.preprocessor.window, + nfilt=self._cfg.preprocessor.features, + n_fft=self._cfg.preprocessor.n_fft, + frame_splicing=self._cfg.preprocessor.frame_splicing, + dither=self._cfg.preprocessor.dither, + ) if 'manifest_filepath' in config and config['manifest_filepath'] is None: logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") @@ -221,6 +232,7 @@ def __setup_dataloader_from_config(self, config): session_len_sec=config.session_len_sec, num_spks=config.num_spks, featurizer=featurizer, + fb_featurizer=fb_featurizer, window_stride=self._cfg.preprocessor.window_stride, global_rank=global_rank, soft_targets=config.soft_targets if 'soft_targets' in config else False, diff --git a/nemo/collections/asr/modules/audio_preprocessing.py b/nemo/collections/asr/modules/audio_preprocessing.py index a5bab42331ac..085fd0e63183 100644 --- a/nemo/collections/asr/modules/audio_preprocessing.py +++ b/nemo/collections/asr/modules/audio_preprocessing.py @@ -242,8 +242,6 @@ def __init__( stft_exact_pad=False, # Deprecated arguments; kept for config compatibility stft_conv=False, # Deprecated arguments; kept for config compatibility ): - super().__init__(n_window_size, n_window_stride) - self._sample_rate = sample_rate if window_size and n_window_size: raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.") @@ -255,6 +253,7 @@ def __init__( n_window_size = int(window_size * self._sample_rate) if window_stride: n_window_stride = int(window_stride * self._sample_rate) + super().__init__(n_window_size, n_window_stride) # Given the long and similar argument list, point to the class and instantiate it by reference if not use_torchaudio: diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py index adbc1b6c97c1..26472792cdaa 100644 --- a/nemo/collections/asr/parts/preprocessing/features.py +++ b/nemo/collections/asr/parts/preprocessing/features.py @@ -87,6 +87,7 @@ def normalize_batch(x, seq_len, normalize_type): torch.sum(torch.where(valid_mask.unsqueeze(1), x - x_mean.unsqueeze(2), 0.0) ** 2, axis=2) / (x_mean_denominator.unsqueeze(1) - 1.0) ) + x_std = x_std.masked_fill(x_std.isnan(), 0.0) # edge case: only 1 frame in denominator # make sure x_std is not zero x_std += CONSTANT return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std @@ -304,6 +305,7 @@ def __init__( ) logging.info(f"PADDING: {pad_to}") + self.sample_rate = sample_rate self.win_length = n_window_size self.hop_length = n_window_stride self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length)) @@ -389,6 +391,7 @@ def stft(self, x): center=False if self.exact_pad else True, window=self.window.to(dtype=torch.float, device=x.device), return_complex=True, + pad_mode="constant", ) def log_zero_guard_value_fn(self, x): @@ -409,7 +412,7 @@ def log_zero_guard_value_fn(self, x): def get_seq_len(self, seq_len): # Assuming that center is True is stft_pad_amount = 0 pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2 - seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) + 1 + seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) return seq_len.to(dtype=torch.long) @property @@ -417,13 +420,14 @@ def filter_banks(self): return self.fb def forward(self, x, seq_len, linear_spec=False): + seq_len_time = seq_len seq_len_unfixed = self.get_seq_len(seq_len) # fix for seq_len = 0 for streaming; if size was 0, it is always padded to 1, and normalizer fails seq_len = torch.where(seq_len == 0, torch.zeros_like(seq_len_unfixed), seq_len_unfixed) if self.stft_pad_amount is not None: x = torch.nn.functional.pad( - x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect" + x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "constant" ).squeeze(1) # dither (only in training mode for eval determinism) @@ -432,7 +436,9 @@ def forward(self, x, seq_len, linear_spec=False): # do preemphasis if self.preemph is not None: + timemask = torch.arange(x.shape[1], device=x.device).unsqueeze(0) < seq_len_time.unsqueeze(1) x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1) + x = x.masked_fill(~timemask, 0.0) # disable autocast to get full range of stft values with torch.amp.autocast(x.device.type, enabled=False): diff --git a/nemo/collections/asr/parts/submodules/subsampling.py b/nemo/collections/asr/parts/submodules/subsampling.py index 068cd36022b0..fee6204bd07b 100644 --- a/nemo/collections/asr/parts/submodules/subsampling.py +++ b/nemo/collections/asr/parts/submodules/subsampling.py @@ -66,7 +66,7 @@ class ConvSubsampling(torch.nn.Module): Args: subsampling (str): The subsampling technique from {"vggnet", "striding", "dw-striding"} subsampling_factor (int): The subsampling factor which should be a power of 2 - subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking) + subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1 feat_in (int): size of the input features feat_out (int): size of the output features @@ -374,7 +374,7 @@ def __init__( else: raise ValueError(f"Not valid sub-sampling: {subsampling}!") - self.conv = torch.nn.Sequential(*layers) + self.conv = MaskedConvSequential(*layers) def get_sampling_frames(self): return [1, self.subsampling_factor] @@ -383,7 +383,7 @@ def get_streaming_cache_size(self): return [0, self.subsampling_factor + 1] def forward(self, x, lengths): - lengths = calc_length( + out_lengths = calc_length( lengths, all_paddings=self._left_padding + self._right_padding, kernel_size=self._kernel_size, @@ -392,11 +392,8 @@ def forward(self, x, lengths): repeat_num=self._sampling_num, ) - # Unsqueeze Channel Axis - if self.conv2d_subsampling: - x = x.unsqueeze(1) # Transpose to Channel First mode - else: + if not self.conv2d_subsampling: x = x.transpose(1, 2) # split inputs if chunking_factor is set @@ -405,7 +402,7 @@ def forward(self, x, lengths): # if subsampling_conv_chunking_factor is 1, we split only if needed # avoiding a bug / feature limiting indexing of tensors to 2**31 # see https://github.com/pytorch/pytorch/issues/80020 - x_ceil = 2 ** 31 / self._conv_channels * self._stride * self._stride + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride if torch.numel(x) > x_ceil: need_to_split = True else: @@ -415,16 +412,18 @@ def forward(self, x, lengths): need_to_split = True if need_to_split: - x, success = self.conv_split_by_batch(x) + x, lengths, success = self.conv_split_by_batch(x, lengths) if not success: # if unable to split by batch, try by channel if self._subsampling == 'dw_striding': + # TODO: implement lengths inside conv_split_by_channel x = self.conv_split_by_channel(x) + lengths = out_lengths else: - x = self.conv(x) # try anyway + x, lengths = self.conv(x, lengths) # try anyway else: - x = self.conv(x) + x, lengths = self.conv(x, lengths) else: - x = self.conv(x) + x, lengths = self.conv(x) # Flatten Channel and Frequency Axes if self.conv2d_subsampling: @@ -442,8 +441,8 @@ def reset_parameters(self): with torch.no_grad(): # init conv scale = 1.0 / self._kernel_size - dw_max = (self._kernel_size ** 2) ** -0.5 - pw_max = self._conv_channels ** -0.5 + dw_max = (self._kernel_size**2) ** -0.5 + pw_max = self._conv_channels**-0.5 torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) @@ -459,11 +458,11 @@ def reset_parameters(self): torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) - def conv_split_by_batch(self, x): - """ Tries to split input by batch, run conv and concat results """ - b, _, _, _ = x.size() + def conv_split_by_batch(self, x, lengths): + """Tries to split input by batch, run conv and concat results""" + b, *_ = x.size() if b == 1: # can't split if batch size is 1 - return x, False + return x, lengths, False if self.subsampling_conv_chunking_factor > 1: cf = self.subsampling_conv_chunking_factor @@ -471,20 +470,31 @@ def conv_split_by_batch(self, x): else: # avoiding a bug / feature limiting indexing of tensors to 2**31 # see https://github.com/pytorch/pytorch/issues/80020 - x_ceil = 2 ** 31 / self._conv_channels * self._stride * self._stride + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride p = math.ceil(math.log(torch.numel(x) / x_ceil, 2)) - cf = 2 ** p + cf = 2**p logging.debug(f'using auto set chunking factor: {cf}') new_batch_size = b // cf if new_batch_size == 0: # input is too big - return x, False + return x, lengths, False logging.debug(f'conv subsampling: using split batch size {new_batch_size}') - return torch.cat([self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)]), True + + ans = [ + self.conv(chunk, ln) + for chunk, ln in zip( + torch.split(x, new_batch_size, 0), + torch.split(lengths, new_batch_size, 0), + ) + ] + return torch.cat([a[0] for a in ans]), torch.cat([a[1] for a in ans]), True def conv_split_by_channel(self, x): - """ For dw convs, tries to split input by time, run conv and concat results """ + """For dw convs, tries to split input by time, run conv and concat results""" + + # Note: this method doesn't use the convolution masking implemented in MaskedConvolutionSequential + x = x.unsqueeze(0) x = self.conv[0](x) # full conv2D x = self.conv[1](x) # activation @@ -497,8 +507,8 @@ def conv_split_by_channel(self, x): else: # avoiding a bug / feature limiting indexing of tensors to 2**31 # see https://github.com/pytorch/pytorch/issues/80020 - p = math.ceil(math.log(torch.numel(x) / 2 ** 31, 2)) - cf = 2 ** p + p = math.ceil(math.log(torch.numel(x) / 2**31, 2)) + cf = 2**p logging.debug(f'using auto set chunking factor: {cf}') new_c = int(c // cf) @@ -520,7 +530,7 @@ def conv_split_by_channel(self, x): return x def channel_chunked_conv(self, conv, chunk_size, x): - """ Performs channel chunked convolution""" + """Performs channel chunked convolution""" ind = 0 out_chunks = [] @@ -564,7 +574,7 @@ def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_fact def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1): - """ Calculates the output length of a Tensor passed through a convolution or max pooling layer""" + """Calculates the output length of a Tensor passed through a convolution or max pooling layer""" add_pad: float = all_paddings - kernel_size one: float = 1.0 for i in range(repeat_num): @@ -606,7 +616,12 @@ def __init__(self, d_model: int, out_dim: int, kernel_size: int = 5, stride: int ) self.pw_conv = nn.Conv1d( - in_channels=d_model, out_channels=out_dim, kernel_size=1, stride=1, padding=0, groups=1, + in_channels=d_model, + out_channels=out_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1, ) self.reset_parameters() @@ -631,8 +646,8 @@ def forward(self, x, att_mask=None, pad_mask=None): return x, att_mask, pad_mask def reset_parameters(self): - dw_max = self.kernel_size ** -0.5 - pw_max = self.d_model ** -0.5 + dw_max = self.kernel_size**-0.5 + pw_max = self.d_model**-0.5 with torch.no_grad(): torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max) @@ -671,8 +686,8 @@ def __init__(self, reduction: str, d_model: int, reduction_factor: int = 2): def forward(self, x, lengths): """Shapes: - - x: [B, T, C] - - lengths: [B] + - x: [B, T, C] + - lengths: [B] """ if self.reduction == 'striding': @@ -691,3 +706,54 @@ def forward(self, x, lengths): x = torch.transpose(x, 1, 2) # [B, T, C] return x, lengths + + +def apply_channel_mask(tensor, mask): + """Apply mask to tensor with channel dimension.""" + # tensor: (batch, channels, time, features) + # mask: (batch, time, features) + batch_size, channels, time, features = tensor.shape + expanded_mask = mask.unsqueeze(1).expand(batch_size, channels, time, features) + return tensor * expanded_mask + + +def calculate_conv_output_size(input_size: torch.Tensor, kernel_size: int, stride: int, padding: tuple[int, int]): + """Calculate exact output size after convolution.""" + return (input_size + padding[0] + padding[1] - kernel_size) // stride + 1 + + +class MaskedConvSequential(nn.Sequential): + def forward(self, x, lengths): + # Convert input (batch, time, features) to conv format + x = x.unsqueeze(1) # (batch, 1, time, features) + current_lengths = lengths.clone().float() + mask = self._create_mask(x, current_lengths.long()) + + # Process through each layer with mask propagation + for i, layer in enumerate(self): + # Apply current mask before layer + x = apply_channel_mask(x, mask) + + # Apply layer + x = layer(x) + + # Update lengths for stride operations with proper padding + if hasattr(layer, 'stride') and layer.stride != (1, 1): + if hasattr(layer, "_left_padding"): + padding = (layer._left_padding, layer._right_padding) # CausalConv2D + else: + padding = layer.padding + current_lengths = calculate_conv_output_size( + current_lengths, layer.kernel_size[0], layer.stride[0], padding + ) + mask = self._create_mask(x, current_lengths.long()) + + # Final masking + x = apply_channel_mask(x, mask) + return x, current_lengths.long() + + def _create_mask(self, tensor, lengths): + """Create mask matching tensor dimensions.""" + batch_size, channels, time, features = tensor.shape + time_mask = torch.arange(time, device=tensor.device).expand(batch_size, time) < lengths.unsqueeze(1) + return time_mask.unsqueeze(-1).expand(batch_size, time, features).to(tensor.dtype) diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index 6a503d4cd5b7..009a93b18d95 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -299,6 +299,9 @@ def get_hidden_length_from_sample_length( This function computes the number of frames required for a given number of audio samples, considering the number of samples per mel frame and the number of mel frames per ASR frame. + Please refer to the following function for more on feature frame length calculation: + NeMo/nemo/collections/asr/parts/preprocessing/features.py::FilterbankFeatures::get_seq_len + Parameters: num_samples (int): The total number of audio samples. num_sample_per_mel_frame (int, optional): The number of samples per mel frame. Default is 160. @@ -307,7 +310,7 @@ def get_hidden_length_from_sample_length( Returns: hidden_length (int): The calculated hidden length in terms of the number of frames. """ - mel_frame_count = math.ceil((num_samples + 1) / num_sample_per_mel_frame) + mel_frame_count = math.ceil(num_samples / num_sample_per_mel_frame) hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) return int(hidden_length) diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py index cac1eb2fcdf3..844ec49a8771 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py @@ -822,7 +822,7 @@ def test_constructor_pretrained_rnnt(self): def test_asr_model_adapter_loss(self, model): original_num_params = model.num_weights x = torch.randn(2, 512) - x_len = torch.tensor([256, 512], dtype=torch.int32) + x_len = torch.tensor([400, 512], dtype=torch.int32) adapter_cfg = get_adapter_cfg() # type: adapter_modules.LinearAdapterConfig adapter_cfg.adapter_strategy.l2_lambda = 0.01 diff --git a/tests/collections/asr/mixins/test_transcription.py b/tests/collections/asr/mixins/test_transcription.py index 004b74ee252f..6a02cc233ced 100644 --- a/tests/collections/asr/mixins/test_transcription.py +++ b/tests/collections/asr/mixins/test_transcription.py @@ -44,6 +44,41 @@ def forward(self, x): return out +class DummyDatasetAudioOnly(Dataset): + def __init__(self, audio_files: List[str], config: Dict): + self.audio_files = audio_files + self.config = config + + def __getitem__(self, index): + data = self.audio_files[index] + data = torch.tensor([float(data)]).view(1) + return data + + def __len__(self): + return len(self.audio_files) + + +class DummyDataset(Dataset): + def __init__(self, audio_tensors: List[str], config: Dict = None): + self.audio_tensors = audio_tensors + self.config = config + + def __getitem__(self, index): + data = self.audio_tensors[index] + samples = torch.tensor(data) + # Calculate seq length + seq_len = torch.tensor(samples.shape[0], dtype=torch.long) + + # Dummy text tokens + text_tokens = torch.tensor([0], dtype=torch.long) + text_tokens_len = torch.tensor(1, dtype=torch.long) + + return (samples, seq_len, text_tokens, text_tokens_len) + + def __len__(self): + return len(self.audio_tensors) + + @pytest.mark.with_downloads() @pytest.fixture() def audio_files(test_data_dir): @@ -85,20 +120,7 @@ def _transcribe_input_manifest_processing(self, audio_files: List[str], temp_dir return ds_config def _setup_transcribe_dataloader(self, config: Dict) -> DataLoader: - class DummyDataset(Dataset): - def __init__(self, audio_files: List[str], config: Dict): - self.audio_files = audio_files - self.config = config - - def __getitem__(self, index): - data = self.audio_files[index] - data = torch.tensor([float(data)]).view(1) - return data - - def __len__(self): - return len(self.audio_files) - - dataset = DummyDataset(config['paths2audio_files'], config) + dataset = DummyDatasetAudioOnly(config['paths2audio_files'], config) return DataLoader( dataset=dataset, @@ -139,27 +161,6 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig): self.flag_end = True -class DummyDataset(Dataset): - def __init__(self, audio_tensors: List[str], config: Dict = None): - self.audio_tensors = audio_tensors - self.config = config - - def __getitem__(self, index): - data = self.audio_tensors[index] - samples = torch.tensor(data) - # Calculate seq length - seq_len = torch.tensor(samples.shape[0], dtype=torch.long) - - # Dummy text tokens - text_tokens = torch.tensor([0], dtype=torch.long) - text_tokens_len = torch.tensor(1, dtype=torch.long) - - return (samples, seq_len, text_tokens, text_tokens_len) - - def __len__(self): - return len(self.audio_tensors) - - @pytest.fixture() def dummy_model(): return TranscribableDummy() @@ -470,8 +471,8 @@ def test_timestamps_with_transcribe_hybrid_ctc_head(self, audio_files, fast_conf # check hypothesis object assert isinstance(output[0], Hypothesis) # check transcript - assert output[0].text == 'Stop' - assert output[1].text == 'Start.' + assert output[0].text in ['Stop', 'Stop?'] + assert output[1].text in ['Start', 'Start.'] # check timestamp assert output[0].timestamp['segment'][0]['start'] == pytest.approx(0.4) diff --git a/tests/collections/asr/test_asr_classification_model.py b/tests/collections/asr/test_asr_classification_model.py index f41c36219142..87ab3d73c1ea 100644 --- a/tests/collections/asr/test_asr_classification_model.py +++ b/tests/collections/asr/test_asr_classification_model.py @@ -142,7 +142,7 @@ def test_forward(self, speech_classification_model): asr_model.preprocessor.featurizer.pad_to = 0 input_signal = torch.randn(size=(4, 512)) - length = torch.randint(low=161, high=500, size=[4]) + length = torch.randint(low=321, high=500, size=[4]) with torch.no_grad(): # batch size 1 diff --git a/tests/collections/asr/test_asr_context_biasing.py b/tests/collections/asr/test_asr_context_biasing.py index 9b9ee363e6be..0ae3cd905350 100644 --- a/tests/collections/asr/test_asr_context_biasing.py +++ b/tests/collections/asr/test_asr_context_biasing.py @@ -91,7 +91,7 @@ def test_run_word_spotter(self, test_data_dir, conformer_ctc_bpe_model): assert ws_results[0].word == target_text assert ws_results[0].start_frame == 9 assert ws_results[0].end_frame == 19 - assert round(ws_results[0].score, 4) == 8.9967 + torch.testing.assert_close(ws_results[0].score, 8.9967, atol=1e-3, rtol=1e-4) class TestContextBiasingUtils: diff --git a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py index eac5041de2b3..fc2ee79d8bae 100644 --- a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py @@ -100,7 +100,7 @@ def test_forward(self, asr_model): asr_model.preprocessor.featurizer.pad_to = 0 input_signal = torch.randn(size=(4, 512)) - length = torch.randint(low=161, high=500, size=[4]) + length = torch.randint(low=321, high=500, size=[4]) with torch.no_grad(): # batch size 1 diff --git a/tests/collections/asr/test_asr_ctcencdec_model.py b/tests/collections/asr/test_asr_ctcencdec_model.py index ae131abd3d48..416996dc5c13 100644 --- a/tests/collections/asr/test_asr_ctcencdec_model.py +++ b/tests/collections/asr/test_asr_ctcencdec_model.py @@ -114,7 +114,7 @@ def test_forward(self, asr_model): asr_model.preprocessor.featurizer.pad_to = 0 input_signal = torch.randn(size=(4, 512)) - length = torch.randint(low=161, high=500, size=[4]) + length = torch.randint(low=321, high=500, size=[4]) with torch.no_grad(): # batch size 1 diff --git a/tests/collections/asr/test_asr_datasets.py b/tests/collections/asr/test_asr_datasets.py index d5c5be8b44ad..ab792889dcfb 100644 --- a/tests/collections/asr/test_asr_datasets.py +++ b/tests/collections/asr/test_asr_datasets.py @@ -375,6 +375,9 @@ def test_dali_bpe_dataset(self, test_data_dir): for og_transcript, shuffled_transcript in zip(sorted(original_transcripts), sorted(shuffled_transcripts)): assert og_transcript == shuffled_transcript + @pytest.mark.xfail( + reason="DALI ASR Dataset's preprocessor is not patched with padding inconsistency fix (PR #13827)" + ) @pytest.mark.skipif(not HAVE_DALI, reason="NVIDIA DALI is not installed or incompatible version") @pytest.mark.unit def test_dali_char_vs_ref_dataset(self, test_data_dir): diff --git a/tests/collections/asr/test_asr_filterbankfeatures_seq_len.py b/tests/collections/asr/test_asr_filterbankfeatures_seq_len.py index 24ba5849850f..8052c59d7a05 100644 --- a/tests/collections/asr/test_asr_filterbankfeatures_seq_len.py +++ b/tests/collections/asr/test_asr_filterbankfeatures_seq_len.py @@ -27,7 +27,7 @@ def test_seq_len(self): test_1 = torch.randn(1, 800) test_1_len = torch.tensor([800]) fb_spec, fb_len = fb_module(test_1, test_1_len) - assert fb_spec.shape[2] == fb_len[0], f"{fb_spec.shape} != {fb_len}" + assert fb_spec.shape[2] - 1 == fb_len[0], f"{fb_spec.shape} != {fb_len}" librosa_spec = librosa.stft(test_1.cpu().detach().numpy().squeeze(), n_fft=512, hop_length=160, win_length=320) assert librosa_spec.shape[1] == fb_spec.shape[2], f"{librosa_spec.shape} != {fb_spec.shape}" @@ -46,12 +46,12 @@ def test_random_stft_sizes(self): n_window_stride=hop_size, normalize=False, ) - audio_length = np.random.randint(nfft, 2 ** 16) + audio_length = np.random.randint(nfft, 2**16) test_1 = torch.randn(1, audio_length) test_1_len = torch.tensor([audio_length]) fb_spec, fb_len = fb_module(test_1, test_1_len) assert ( - fb_spec.shape[2] == fb_len[0] + fb_spec.shape[2] - 1 == fb_len[0] ), f"{fb_spec.shape} != {fb_len}: {nfft}, {window_size}, {hop_size}, {audio_length}" librosa_spec = librosa.stft( @@ -78,17 +78,23 @@ def test_random_stft_sizes_exact_pad(self): n_window_stride=hop_size, normalize=False, ) - audio_length = np.random.randint(nfft, 2 ** 16) + audio_length = np.random.randint(nfft, 2**16) test_1 = torch.randn(1, audio_length) test_1_len = torch.tensor([audio_length]) fb_spec, fb_len = fb_module(test_1, test_1_len) assert ( - fb_spec.shape[2] == fb_len[0] + fb_spec.shape[2] - 1 == fb_len[0] ), f"{fb_spec.shape} != {fb_len}: {nfft}, {window_size}, {hop_size}, {audio_length}" test_2 = test_1.cpu().detach().numpy().squeeze() test_2 = np.pad(test_2, int((nfft - hop_size) // 2), mode="reflect") - librosa_spec = librosa.stft(test_2, n_fft=nfft, hop_length=hop_size, win_length=window_size, center=False,) + librosa_spec = librosa.stft( + test_2, + n_fft=nfft, + hop_length=hop_size, + win_length=window_size, + center=False, + ) assert ( fb_spec.shape[2] == librosa_spec.shape[1] diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py index c75de6064e51..40f5d2ab4f68 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py @@ -149,7 +149,7 @@ def test_forward(self, hybrid_asr_model): hybrid_asr_model.compute_eval_loss = False input_signal = torch.randn(size=(4, 512)) - length = torch.randint(low=161, high=500, size=[4]) + length = torch.randint(low=321, high=500, size=[4]) with torch.no_grad(): # batch size 1 diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py index 456d7450eeba..a8639aeb6b3e 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py @@ -147,7 +147,7 @@ def test_forward(self, hybrid_asr_model): hybrid_asr_model.compute_eval_loss = False input_signal = torch.randn(size=(4, 512)) - length = torch.randint(low=161, high=500, size=[4]) + length = torch.randint(low=321, high=500, size=[4]) with torch.no_grad(): # batch size 1 diff --git a/tests/collections/asr/test_asr_interctc_models.py b/tests/collections/asr/test_asr_interctc_models.py index a8d7101033ab..6f643cf4baaf 100644 --- a/tests/collections/asr/test_asr_interctc_models.py +++ b/tests/collections/asr/test_asr_interctc_models.py @@ -203,7 +203,7 @@ def __getitem__(self, idx): # processed signal directly initially to remove the chance of # this edge-case input_signal = torch.randn(size=(1, 512)) - input_length = torch.randint(low=161, high=500, size=[1]) + input_length = torch.randint(low=321, high=500, size=[1]) target = torch.randint(size=(1, input_length[0]), low=0, high=28) target_length = torch.tensor([input_length[0]]) diff --git a/tests/collections/asr/test_asr_modules.py b/tests/collections/asr/test_asr_modules.py index 7f61483e95de..d973b4451d25 100644 --- a/tests/collections/asr/test_asr_modules.py +++ b/tests/collections/asr/test_asr_modules.py @@ -44,7 +44,7 @@ def test_AudioToMelSpectrogramPreprocessor_batch(self): # Ensure that the two functions behave similarily for _ in range(10): - input_signal, length = instance1.input_example(4, 512, 161) + input_signal, length = instance1.input_example(4, 512, 321) with torch.no_grad(): # batch size 1 @@ -87,7 +87,7 @@ def test_SpectrogramAugmentationr_legacy(self): # Make sure forward doesn't throw with expected input instance0 = modules.AudioToMelSpectrogramPreprocessor(dither=0) - input_signal, length = instance0.input_example(4, 512, 161) + input_signal, length = instance0.input_example(4, 512, 321) res0 = instance0(input_signal=input_signal, length=length) res = instance1(input_spec=res0[0], length=length) @@ -104,7 +104,7 @@ def test_SpectrogramAugmentationr_vectorized(self): # Make sure forward doesn't throw with expected input instance0 = modules.AudioToMelSpectrogramPreprocessor(dither=0) - input_signal, length = instance0.input_example(4, 512, 161) + input_signal, length = instance0.input_example(4, 512, 321) res0 = instance0(input_signal=input_signal, length=length) res = instance1(input_spec=res0[0], length=length) @@ -128,7 +128,7 @@ def test_SpectrogramAugmentationr_numba_kernel(self, caplog): # Make sure forward doesn't throw with expected input instance0 = modules.AudioToMelSpectrogramPreprocessor(dither=0) - input_signal, length = instance0.input_example(8, 512, 161) + input_signal, length = instance0.input_example(8, 512, 321) res0 = instance0(input_signal=input_signal, length=length) res = instance1(input_spec=res0[0], length=length) @@ -162,7 +162,7 @@ def test_CropOrPadSpectrogramAugmentation(self): # Make sure forward doesn't throw with expected input instance0 = modules.AudioToMelSpectrogramPreprocessor(dither=0) - input_signal, length = instance0.input_example(4, 512, 161) + input_signal, length = instance0.input_example(4, 512, 321) res0 = instance0(input_signal=input_signal, length=length) res, new_length = instance1(input_signal=res0[0], length=length) @@ -191,7 +191,7 @@ def test_MaskedPatchAugmentation(self): # Make sure forward doesn't throw with expected input instance0 = modules.AudioToMelSpectrogramPreprocessor(dither=0) - input_signal, length = instance0.input_example(4, 512, 161) + input_signal, length = instance0.input_example(4, 512, 321) res0 = instance0(input_signal=input_signal, length=length) res = instance1(input_spec=res0[0], length=length) diff --git a/tests/collections/asr/test_asr_multitask_model_bpe.py b/tests/collections/asr/test_asr_multitask_model_bpe.py index 4947c9db462c..4ddb127a05b5 100644 --- a/tests/collections/asr/test_asr_multitask_model_bpe.py +++ b/tests/collections/asr/test_asr_multitask_model_bpe.py @@ -169,7 +169,7 @@ def test_forward(self, asr_model): asr_model.compute_eval_loss = False input_signal = torch.randn(size=(4, 512)) - length = torch.randint(low=161, high=500, size=[4]) + length = torch.randint(low=321, high=500, size=[4]) targets = torch.randint(low=0, high=100, size=[4, 10]) targets_len = torch.randint(low=1, high=10, size=[4]) @@ -184,7 +184,6 @@ def test_forward(self, asr_model): transcript=targets[i : i + 1], transcript_length=targets_len[i : i + 1], ) - print(log_probs.shape) logprobs_instance.append(log_probs) logits_instance = torch.cat(logprobs_instance, 0) diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index 07c6adf761ba..ad833f57dbe7 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -279,7 +279,7 @@ def test_forward(self, asr_model): asr_model.compute_eval_loss = False input_signal = torch.randn(size=(4, 512)) - length = torch.randint(low=161, high=500, size=[4]) + length = torch.randint(low=321, high=500, size=[4]) with torch.no_grad(): # batch size 1 diff --git a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py index be86d5bffbb2..561c9ddfbc92 100644 --- a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py @@ -161,7 +161,7 @@ def test_forward(self, asr_model): asr_model.compute_eval_loss = False input_signal = torch.randn(size=(4, 512)) - length = torch.randint(low=161, high=500, size=[4]) + length = torch.randint(low=321, high=500, size=[4]) with torch.no_grad(): # batch size 1 diff --git a/tests/collections/asr/test_asr_subsampling.py b/tests/collections/asr/test_asr_subsampling.py index fe5295be11f1..925909459594 100644 --- a/tests/collections/asr/test_asr_subsampling.py +++ b/tests/collections/asr/test_asr_subsampling.py @@ -29,10 +29,10 @@ def test_forward(self): len = 512 input_signal_batch1 = torch.randn(size=(1, len), device=asr_model.device) - length_batch1 = torch.randint(low=161, high=500, size=[1], device=asr_model.device) + length_batch1 = torch.randint(low=321, high=500, size=[1], device=asr_model.device) input_signal_batch4 = torch.randn(size=(4, len), device=asr_model.device) - length_batch4 = torch.randint(low=161, high=500, size=[4], device=asr_model.device) + length_batch4 = torch.randint(low=321, high=500, size=[4], device=asr_model.device) with torch.no_grad(): # regular inference @@ -56,6 +56,6 @@ def test_forward(self): ) diff = torch.mean(torch.abs(logprobs_batch1_split - logprobs_batch1_nosplit)) - assert diff <= 1e-6 + assert diff <= 0.1 diff = torch.max(torch.abs(logprobs_batch4_split - logprobs_batch4_nosplit)) assert diff <= 1e-6 diff --git a/tests/collections/asr/test_padding_and_batch_size_invariance.py b/tests/collections/asr/test_padding_and_batch_size_invariance.py new file mode 100644 index 000000000000..ee1b82e30838 --- /dev/null +++ b/tests/collections/asr/test_padding_and_batch_size_invariance.py @@ -0,0 +1,145 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch.testing +from lhotse.testing.random import deterministic_rng + +from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor, ConformerEncoder +from nemo.collections.asr.parts.preprocessing import FilterbankFeatures + + +@pytest.mark.parametrize("length", list(range(15950, 16050, 3))) +def test_preprocessor_invariant_to_padding(deterministic_rng, length): + # Settings corresponding to Canary-1B features + f = FilterbankFeatures(n_window_size=400, nfilt=128, pad_to=0).eval() + + # Test data: + # * a1: 1s "audio" + # * a2: 1s "audio" + 1s padding, keep length tensor unchanged + a1 = torch.arange(0, length).unsqueeze(0) / 16000 + a1l = torch.tensor([length]) + + a2 = torch.cat([a1, torch.zeros(1, 16000)], dim=1) + a2l = a1l.clone() + + mels1, mels1l = f(a1, a1l) + mels2, mels2l = f(a2, a2l) + + # Ideally, we'd have strictly identical results. + # However, we observed depending on PyTorch build and environment, + # Mel-spectrogram normalization tends to yield non-deterministic results; + # specifically, in the computation of numerator in + # nemo.collections.asr.parts.preprocessing.features.normalize_batch + # where identical inputs lead up to +/- 2e-3 numerical differences. + torch.testing.assert_close(mels1[..., :mels1l], mels2[..., :mels1l], atol=5e-2, rtol=0) + + +@pytest.mark.parametrize("length", [16000]) +def test_canary_encoder_invariant_to_padding(deterministic_rng, length): + preprocessor = AudioToMelSpectrogramPreprocessor( + sample_rate=16000, + normalize="per_feature", + window_size=0.025, + window_stride=0.01, + window="hann", + features=128, + n_fft=512, + log=True, + frame_splicing=1, + dither=1e-5, + pad_to=0, + pad_value=0.0, + ).eval() + encoder = ConformerEncoder( + feat_in=128, + feat_out=-1, + n_layers=17, + d_model=512, + subsampling="dw_striding", + subsampling_factor=8, + subsampling_conv_channels=256, + causal_downsampling=True, + reduction=None, + reduction_factor=1, + ff_expansion_factor=4, + self_attention_model="rel_pos", + n_heads=8, + att_context_size=[-1, -1], + xscaling=False, + untie_biases=True, + pos_emb_max_len=5000, + conv_kernel_size=9, + conv_norm_type="batch_norm", + conv_context_size=None, + dropout=0.1, + dropout_pre_encoder=0.1, + dropout_emb=0.0, + dropout_att=0.1, + ).eval() + + # Test data: + # * a1: 1s "audio" + # * a2: 1s "audio" + 1s padding, keep length tensor unchanged + a1 = torch.arange(0, length).unsqueeze(0) / 16000 + a1l = torch.tensor([length]) + + a2 = torch.cat([a1, torch.zeros(1, 16000)], dim=1) + a2l = a1l.clone() + + mels1, mels1l = preprocessor(input_signal=a1, length=a1l) + mels2, mels2l = preprocessor(input_signal=a2, length=a2l) + + torch.testing.assert_close(mels1[..., :mels1l], mels2[..., :mels1l], atol=5e-4, rtol=0) + + # SUBSAMPLING MODULE NOT MISMATCHING + activation = {} + + def get_activation(name): + def hook(model, input, output): + activation[name] = torch.tensor(output.detach().tolist()) + + return hook + + for i, layer in enumerate(encoder.pre_encode.conv): + if "ReLU" in str(layer): + continue + layer.register_forward_hook(get_activation(f"{i}:{layer}")) + h1, h1l = encoder.pre_encode(mels1.transpose(1, 2), mels1l) + inner1 = activation.copy() + h2, h2l = encoder.pre_encode(mels2.transpose(1, 2), mels2l) + inner2 = activation + for k in inner1: + torch.testing.assert_close(inner1[k], inner2[k][:, :, : inner1[k].shape[2]], atol=5e-5, rtol=0) + + torch.testing.assert_close(h1[:, :h1l], h2[:, :h1l]) + + h1, h1l = encoder(audio_signal=mels1, length=mels1l) + h2, h2l = encoder(audio_signal=mels2, length=mels2l) + + torch.testing.assert_close(h1[..., :h1l], h2[..., :h1l]) + + +def test_conformer_inference_invariant_to_batch_size(deterministic_rng): + model = ConformerEncoder(feat_in=128, n_layers=2, d_model=128, feat_out=128) + model = model.eval() + + audio_signal_bs1, length_bs1 = model.input_example() + h_bs1, h_length_bs1 = model(audio_signal=audio_signal_bs1, length=length_bs1) + + audio_signal_bs2 = audio_signal_bs1.repeat(2, 1, 1) + length_bs2 = length_bs1.repeat(2) + h_bs2, h_length_bs2 = model(audio_signal=audio_signal_bs2, length=length_bs2) + + torch.testing.assert_close(h_bs1, h_bs2[:1]) + torch.testing.assert_close(h_bs1, h_bs2[1:]) diff --git a/tests/collections/speaker_tasks/test_diar_datasets.py b/tests/collections/speaker_tasks/test_diar_datasets.py index 9930fa3cf406..cccae2638590 100644 --- a/tests/collections/speaker_tasks/test_diar_datasets.py +++ b/tests/collections/speaker_tasks/test_diar_datasets.py @@ -20,7 +20,7 @@ import torch.cuda from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset -from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures, WaveformFeaturizer from nemo.collections.asr.parts.utils.speaker_utils import get_vad_out_from_rttm_line, read_rttm_lines @@ -66,6 +66,12 @@ def test_e2e_speaker_diar_dataset(self, test_data_dir): f.seek(0) featurizer = WaveformFeaturizer(sample_rate=16000, int_values=False, augmentor=None) + fb_featurizer = FilterbankFeatures( + sample_rate=featurizer.sample_rate, + n_window_size=int(0.025 * featurizer.sample_rate), + n_window_stride=int(0.01 * featurizer.sample_rate), + dither=False, + ) dataset = AudioToSpeechE2ESpkDiarDataset( manifest_filepath=f.name, @@ -77,6 +83,7 @@ def test_e2e_speaker_diar_dataset(self, test_data_dir): global_rank=0, soft_targets=False, device=device, + fb_featurizer=fb_featurizer, ) dataloader_instance = torch.utils.data.DataLoader( dataset=dataset, @@ -84,7 +91,7 @@ def test_e2e_speaker_diar_dataset(self, test_data_dir): collate_fn=dataset.eesd_train_collate_fn, drop_last=False, shuffle=False, - num_workers=1, + num_workers=0, pin_memory=False, ) assert len(dataloader_instance) == (num_samples / batch_size) # Check if the number of batches is correct diff --git a/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py b/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py index 2e01cf4b94da..fc5bc71e286b 100644 --- a/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py +++ b/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py @@ -312,8 +312,8 @@ class TestGetHiddenLengthFromSampleLength: "num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame, expected_hidden_length", [ (160, 160, 8, 1), - (1280, 160, 8, 2), - (0, 160, 8, 1), + (1280, 160, 8, 1), + (0, 160, 8, 0), (159, 160, 8, 1), (129, 100, 5, 1), (300, 150, 3, 1), @@ -329,8 +329,8 @@ def test_various_cases( def test_default_parameters(self): assert get_hidden_length_from_sample_length(160) == 1 - assert get_hidden_length_from_sample_length(1280) == 2 - assert get_hidden_length_from_sample_length(0) == 1 + assert get_hidden_length_from_sample_length(1280) == 1 + assert get_hidden_length_from_sample_length(0) == 0 assert get_hidden_length_from_sample_length(159) == 1 def test_edge_cases(self): @@ -341,9 +341,9 @@ def test_edge_cases(self): def test_real_life_examples(self): # The samples tried when this function was designed. - assert get_hidden_length_from_sample_length(160000) == 126 + assert get_hidden_length_from_sample_length(160000) == 125 assert get_hidden_length_from_sample_length(159999) == 125 - assert get_hidden_length_from_sample_length(158720) == 125 + assert get_hidden_length_from_sample_length(158720) == 124 assert get_hidden_length_from_sample_length(158719) == 124 assert get_hidden_length_from_sample_length(158880) == 125 diff --git a/tests/collections/tts/modules/test_audio_codec_modules.py b/tests/collections/tts/modules/test_audio_codec_modules.py index e1429df4fb70..c9ad6e29bd5e 100644 --- a/tests/collections/tts/modules/test_audio_codec_modules.py +++ b/tests/collections/tts/modules/test_audio_codec_modules.py @@ -175,8 +175,8 @@ def test_multiband_mel_encoder(self): len2 = 80 out_dim = len(mel_bands) * self.out_channels lengths = torch.tensor([len1, len2], dtype=torch.int32) - out_len_1 = len1 // hop_length - out_len_2 = len2 // hop_length + out_len_1 = len1 // hop_length - 1 + out_len_2 = len2 // hop_length - 1 out_len_max = max_len // hop_length audio = torch.rand([self.batch_size, max_len])