diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/streaming_sortformer_diarizer_4spk-v2.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/streaming_sortformer_diarizer_4spk-v2.yaml new file mode 100644 index 000000000000..c6bc6f1ad1de --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/streaming_sortformer_diarizer_4spk-v2.yaml @@ -0,0 +1,234 @@ +# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. +# Model name convention for Sortformer Diarizer: streaming_sortformer_diarizer_-.yaml +# (Example) `streaming_sortformer_diarizer_4spk-v2.yaml`. +# Sortformer Diarizer model checkpoint (.ckpt) and NeMo file (.nemo) contain Fast Conformer Encoder model (NEST Encoder) and the pre-trained NEST model is loaded along with the Transformer Encoder layers. +# Example: a manifest line for training +# {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"} +name: "StreamingSortformerDiarizer" +num_workers: 18 +batch_size: 4 + +model: + sample_rate: 16000 + pil_weight: 0.5 # Weight for Permutation Invariant Loss (PIL) used in training the Sortformer diarizer model + ats_weight: 0.5 # Weight for Arrival Time Sort (ATS) loss in training the Sortformer diarizer model + max_num_of_spks: 4 # Maximum number of speakers per model; currently set to 4 + streaming_mode: True + + model_defaults: + fc_d_model: 512 # Hidden dimension size of the Fast-conformer Encoder + tf_d_model: 192 # Hidden dimension size of the Transformer Encoder + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + num_spks: ${model.max_num_of_spks} + session_len_sec: 90 # Maximum session length in seconds + soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity. + soft_targets: False # If True, use continuous values as target values when calculating cross-entropy loss + labels: null + batch_size: ${batch_size} + shuffle: True + num_workers: ${num_workers} + validation_mode: False + # lhotse config + use_lhotse: False + use_bucketing: True + num_buckets: 10 + bucket_duration_bins: [10, 20, 30, 40, 50, 60, 70, 80, 90] + pin_memory: True + min_duration: 10 + max_duration: 90 + batch_duration: 400 + quadratic_duration: 1200 + bucket_buffer_size: 20000 + shuffle_buffer_size: 10000 + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + validation_ds: + manifest_filepath: ??? + is_tarred: False + tarred_audio_filepaths: null + sample_rate: ${model.sample_rate} + num_spks: ${model.max_num_of_spks} + session_len_sec: 90 # Maximum session length in seconds + soft_label_thres: 0.5 # A threshold value for setting up the binarized labels. The higher the more conservative the model becomes. 
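
As a back-of-the-envelope aside (not part of the patch): the dataloader and streaming settings in this config are expressed in diarization frames, and with the preprocessor's 10 ms window_stride and the encoder's subsampling_factor of 8 one output frame covers roughly 80 ms of audio. The arithmetic below is derived from those values, not stated in the YAML itself:

    # Frame math implied by this config (illustrative sketch, not part of the YAML).
    frame_ms = 10 * 8                              # window_stride (10 ms) x subsampling_factor = 80 ms/frame
    frames_per_session = 90 * 1000 // frame_ms     # 1125 frames for a 90 s training session
    spkcache_sec = 188 * frame_ms / 1000           # spkcache_len=188 frames covers about 15 s of audio
    print(frame_ms, frames_per_session, spkcache_sec)   # 80 1125 15.04
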
+ soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: False + num_workers: ${num_workers} + validation_mode: True + # lhotse config + use_lhotse: False + use_bucketing: False + drop_last: False + pin_memory: True + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + test_ds: + manifest_filepath: null + is_tarred: False + tarred_audio_filepaths: null + sample_rate: 16000 + num_spks: ${model.max_num_of_spks} + session_len_sec: 90 # Maximum session length in seconds + soft_label_thres: 0.5 + soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: False + seq_eval_mode: True + num_workers: ${num_workers} + validation_mode: True + # lhotse config + use_lhotse: False + use_bucketing: False + drop_last: False + pin_memory: True + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + normalize: "NA" + window_size: 0.025 + sample_rate: ${model.sample_rate} + window_stride: 0.01 + window: "hann" + features: 128 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + + sortformer_modules: + _target_: nemo.collections.asr.modules.sortformer_modules.SortformerModules + num_spks: ${model.max_num_of_spks} # Maximum number of speakers the model can handle + dropout_rate: 0.5 # Dropout rate + fc_d_model: ${model.model_defaults.fc_d_model} # Hidden dimension size for Fast Conformer encoder + tf_d_model: ${model.model_defaults.tf_d_model} # Hidden dimension size for Transformer encoder + # Streaming mode parameters + spkcache_len: 188 # Length of speaker cache buffer (total number of frames for all speakers) + fifo_len: 0 # Length of FIFO buffer for streaming processing (0 = disabled) + chunk_len: 188 # Number of frames processed in each streaming chunk + spkcache_update_period: 1 # Speaker cache update period in frames + chunk_left_context: 1 # Number of previous frames for each streaming chunk + chunk_right_context: 1 # Number of future frames for each streaming chunk + # Speaker cache update parameters + spkcache_sil_frames_per_spk: 3 # Number of silence frames allocated per speaker in the speaker cache + scores_add_rnd: 0 # Standard deviation of random noise added to scores in speaker cache update (training only) + pred_score_threshold: 0.25 # Probability threshold for internal scores processing in speaker cache update + max_index: 99999 # Maximum allowed index value for internal processing in speaker cache update + scores_boost_latest: 0.05 # Gain for scores for recently added frames in speaker cache update + sil_threshold: 0.2 # Threshold for determining silence frames to calculate average silence embedding + strong_boost_rate: 0.75 # Rate determining number of frames per speaker that receive strong score boosting + weak_boost_rate: 1.5 # Rate determining number of frames per speaker that receive weak score boosting + min_pos_scores_rate: 0.5 # Rate threshold for dropping overlapping frames when enough non-overlapping exist + # Self-attention parameters (training only) + causal_attn_rate: 0.5 # Proportion of batches that use self-attention with limited right context + causal_attn_rc: 7 # Right context size for self-attention with limited right context + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 + n_layers: 17 + d_model: ${model.model_defaults.fc_d_model} + + # Sub-sampling 
parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + # Feed forward module's params + ff_expansion_factor: 4 + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + conv_context_size: null + # Regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + # Set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + transformer_encoder: + _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder + num_layers: 18 + hidden_size: ${model.model_defaults.tf_d_model} # Needs to be multiple of num_attention_heads + inner_size: 768 + num_attention_heads: 8 + attn_score_dropout: 0.5 + attn_layer_dropout: 0.5 + ffn_dropout: 0.5 + hidden_act: relu + pre_ln: False + pre_ln_final_layer_norm: True + + loss: + _target_: nemo.collections.asr.losses.bce_loss.BCELoss + weight: null # Weight for binary cross-entropy loss. Either `null` or list type input. (e.g. [0.5,0.5]) + reduction: mean + + lr: 0.0001 + optim: + name: adamw + lr: ${model.lr} + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + sched: + name: InverseSquareRootAnnealing + warmup_steps: 500 + warmup_ratio: null + min_lr: 1e-06 + +trainer: + devices: 1 # number of gpus (devices) + accelerator: gpu + max_epochs: 800 + max_steps: -1 # computed at runtime if not set + num_nodes: 1 + strategy: ddp_find_unused_parameters_true # Could be "ddp" + accumulate_grad_batches: 1 + deterministic: True + enable_checkpointing: False + logger: False + log_every_n_steps: 1 # Interval of logging. + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + +exp_manager: + use_datetime_version: False + exp_dir: null + name: ${name} + resume_if_exists: True + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. 
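
As a rough usage sketch of how this config is consumed: the SortformerEncLabelModel(cfg, trainer) constructor signature appears in sortformer_diar_models.py further down in this diff, but the import path and the bare-Trainer wiring below are assumptions; the dedicated training script under examples/ is the intended entry point.

    # Hedged sketch: building the model from this config. Only the (cfg, trainer)
    # constructor is taken from this PR; everything else here is an assumption.
    import lightning.pytorch as pl
    from omegaconf import OmegaConf
    from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel

    cfg = OmegaConf.load("streaming_sortformer_diarizer_4spk-v2.yaml")
    cfg.model.train_ds.manifest_filepath = "/path/to/train_manifest.json"      # replaces the ??? placeholder
    cfg.model.validation_ds.manifest_filepath = "/path/to/val_manifest.json"   # replaces the ??? placeholder

    trainer = pl.Trainer(**cfg.trainer)
    diar_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer)
    trainer.fit(diar_model)
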
+ resume_ignore_no_checkpoint: True + create_tensorboard_logger: True + create_checkpoint_callback: True + create_wandb_logger: False + checkpoint_callback_params: + monitor: "val_f1_acc" + mode: "max" + save_top_k: 9 + every_n_epochs: 1 + wandb_logger_kwargs: + resume: True + name: null + project: null \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/conf/post_processing/diar_streaming_sortformer_4spk-v2_callhome-part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/diar_streaming_sortformer_4spk-v2_callhome-part1.yaml new file mode 100644 index 000000000000..6083004f613d --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/post_processing/diar_streaming_sortformer_4spk-v2_callhome-part1.yaml @@ -0,0 +1,11 @@ +# Postprocessing parameters for timestamp outputs from speaker diarization models. +# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper: +# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). +# These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the part1 (callhome1) specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v2/run.sh +parameters: + onset: 0.641 # Onset threshold for detecting the beginning and end of a speech + offset: 0.561 # Offset threshold for detecting the end of a speech + pad_onset: 0.229 # Adding durations before each speech segment + pad_offset: 0.079 # Adding durations after each speech segment + min_duration_on: 0.511 # Threshold for small non-speech deletion + min_duration_off: 0.296 # Threshold for short speech segment deletion \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/conf/post_processing/diar_streaming_sortformer_4spk-v2_dihard3-dev.yaml b/examples/speaker_tasks/diarization/conf/post_processing/diar_streaming_sortformer_4spk-v2_dihard3-dev.yaml new file mode 100644 index 000000000000..a9e10471e5b6 --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/post_processing/diar_streaming_sortformer_4spk-v2_dihard3-dev.yaml @@ -0,0 +1,11 @@ +# Postprocessing parameters for timestamp outputs from speaker diarization models. +# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper: +# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). +# These parameters were optimized on the development split of DIHARD3 dataset (See https://arxiv.org/pdf/2012.01477). 
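
The onset/offset/padding/min-duration values in these post-processing files follow the TS-VAD-style scheme referenced in the header comments. As a simplified illustration of how such parameters turn per-frame speaker probabilities into speech segments (this is not NeMo's actual implementation; the defaults below are the CallHome-part1 values from the file above, and the 0.08 s frame step is the derived 80 ms frame rate):

    # Simplified onset/offset binarization sketch for one speaker's frame probabilities.
    def binarize_one_speaker(probs, frame_sec=0.08, onset=0.641, offset=0.561,
                             pad_onset=0.229, pad_offset=0.079):
        segments, active, start = [], False, 0.0
        for i, p in enumerate(probs):
            t = i * frame_sec
            if not active and p >= onset:     # a segment opens once prob crosses the onset threshold
                active, start = True, t
            elif active and p < offset:       # and closes once prob falls below the offset threshold
                segments.append((max(0.0, start - pad_onset), t + pad_offset))
                active = False
        if active:
            segments.append((max(0.0, start - pad_onset), len(probs) * frame_sec + pad_offset))
        # min_duration_on / min_duration_off are then applied to remove overly short
        # speech or non-speech stretches before the RTTM output is written.
        return segments

    print(binarize_one_speaker([0.1, 0.7, 0.9, 0.8, 0.3, 0.2, 0.1]))
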
+parameters: + onset: 0.56 # Onset threshold for detecting the beginning and end of a speech + offset: 1.0 # Offset threshold for detecting the end of a speech + pad_onset: 0.063 # Adding durations before each speech segment + pad_offset: 0.002 # Adding durations after each speech segment + min_duration_on: 0.007 # Threshold for small non-speech deletion + min_duration_off: 0.151 # Threshold for short speech segment deletion \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index d00509520cd6..f96bb01bd710 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -92,7 +92,7 @@ class DiarizationConfig: log: bool = False # If True, log will be printed use_lhotse: bool = True - batch_duration: int = 33000 + batch_duration: int = 100000 # Eval Settings: (0.25, False) should be default setting for sortformer eval. collar: float = 0.25 # Collar in seconds for DER calculation @@ -100,7 +100,7 @@ class DiarizationConfig: # Streaming diarization configs spkcache_len: int = 188 - spkcache_refresh_rate: int = 144 + spkcache_update_period: int = 144 fifo_len: int = 188 chunk_len: int = 6 chunk_left_context: int = 1 @@ -386,7 +386,10 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: diar_model.sortformer_modules.chunk_right_context = cfg.chunk_right_context diar_model.sortformer_modules.fifo_len = cfg.fifo_len diar_model.sortformer_modules.log = cfg.log - diar_model.sortformer_modules.spkcache_refresh_rate = cfg.spkcache_refresh_rate + diar_model.sortformer_modules.spkcache_update_period = cfg.spkcache_update_period + + # Check if the streaming parameters are valid + diar_model.sortformer_modules._check_streaming_parameters() postprocessing_cfg = load_postprocessing_from_yaml(cfg.postprocessing_yaml) tensor_path, model_id, tensor_filename = get_tensor_path(cfg) diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 63b6db61bbda..2888fd04436d 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -49,42 +49,6 @@ __all__ = ['SortformerEncLabelModel'] -def concat_and_pad(embs: List[torch.Tensor], lengths: List[torch.Tensor]): - """ - Concatenates lengths[i] first embeddings of embs[i], and pads the rest elements with zeros. - - Args: - embs: List of embeddings Tensors of (batch_size, n_frames, emb_dim) shape - lengths: List of lengths Tensors of (batch_size,) shape - - Returns: - output: concatenated embeddings Tensor of (batch_size, n_frames, emb_dim) shape - total_lengths: output lengths Tensor of (batch_size,) shape - """ - - if len(embs) != len(lengths): - raise ValueError( - f"Length lists must have the same length, but got len(embs) - {len(embs)} " - f"and len(lengths) - {len(lengths)}." 
- ) - device, dtype = embs[0].device, embs[0].dtype - batch_size, emb_dim = embs[0].shape[0], embs[0].shape[2] - - total_lengths = torch.sum(torch.stack(lengths), dim=0) - sig_length = total_lengths.max().item() - - output = torch.zeros(batch_size, sig_length, emb_dim, device=device, dtype=dtype) - start_indices = torch.zeros(batch_size, dtype=torch.int64, device=device) - - for emb, length in zip(embs, lengths): - end_indices = start_indices + length - for batch_idx in range(batch_size): - output[batch_idx, start_indices[batch_idx] : end_indices[batch_idx]] = emb[batch_idx, : length[batch_idx]] - start_indices = end_indices - - return output, total_lengths - - class SortformerEncLabelModel(ModelPT, ExportableEncDecModel, SpkDiarizationMixin): """ Encoder class for Sortformer diarization model. @@ -160,7 +124,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations self.max_batch_dur = self._cfg.get("max_batch_dur", 20000) - self.concat_and_pad_script = torch.jit.script(concat_and_pad) + self.concat_and_pad_script = torch.jit.script(self.sortformer_modules.concat_and_pad) def _init_loss_weights(self): pil_weight = self._cfg.get("pil_weight", 0.0) @@ -779,9 +743,11 @@ def forward_streaming_step( ) if self.async_streaming: - spkcache_fifo_chunk_pre_encode_embs, spkcache_fifo_chunk_pre_encode_lengths = concat_and_pad( - [streaming_state.spkcache, streaming_state.fifo, chunk_pre_encode_embs], - [streaming_state.spkcache_lengths, streaming_state.fifo_lengths, chunk_pre_encode_lengths], + spkcache_fifo_chunk_pre_encode_embs, spkcache_fifo_chunk_pre_encode_lengths = ( + self.sortformer_modules.concat_and_pad( + [streaming_state.spkcache, streaming_state.fifo, chunk_pre_encode_embs], + [streaming_state.spkcache_lengths, streaming_state.fifo_lengths, chunk_pre_encode_lengths], + ) ) else: spkcache_fifo_chunk_pre_encode_embs = self.sortformer_modules.concat_embs( @@ -790,7 +756,6 @@ def forward_streaming_step( spkcache_fifo_chunk_pre_encode_lengths = ( streaming_state.spkcache.shape[1] + streaming_state.fifo.shape[1] + chunk_pre_encode_lengths ) - spkcache_fifo_chunk_fc_encoder_embs, spkcache_fifo_chunk_fc_encoder_lengths = self.frontend_encoder( processed_signal=spkcache_fifo_chunk_pre_encode_embs, processed_signal_length=spkcache_fifo_chunk_pre_encode_lengths, @@ -843,6 +808,13 @@ def _get_aux_train_evaluations(self, preds, targets, target_lens) -> dict: Returns: (dict): A dictionary containing the following training metrics. """ + if preds.shape[1] < targets.shape[1]: + logging.info( + f"WARNING! preds has less frames than targets ({preds.shape[1]} < {targets.shape[1]}). " + "Truncating targets and clamping target_lens." + ) + targets = targets[:, : preds.shape[1], :] + target_lens = target_lens.clamp(max=preds.shape[1]) targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) ats_loss = self.loss(probs=preds, labels=targets_ats, target_lens=target_lens) @@ -908,6 +880,13 @@ def _get_aux_validation_evaluations(self, preds, targets, target_lens) -> dict: Returns: val_metrics (dict): A dictionary containing the following validation metrics """ + if preds.shape[1] < targets.shape[1]: + logging.info( + f"WARNING! preds has less frames than targets ({preds.shape[1]} < {targets.shape[1]}). " + "Truncating targets and clamping target_lens." 
+ ) + targets = targets[:, : preds.shape[1], :] + target_lens = target_lens.clamp(max=preds.shape[1]) targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) @@ -968,6 +947,29 @@ def validation_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0): self.validation_step_outputs.append(val_metrics) return val_metrics + def test_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0): + """ + Performs a single validation step. + + This method processes a batch of data during the validation phase. It forward passes + the audio signal through the model, computes various validation metrics, and stores + these metrics for later aggregation. + + Args: + batch (list): A list containing the following elements: + - audio_signal (torch.Tensor): The input audio signal. + - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + - targets (torch.Tensor): The target labels for the batch. + - target_lens (torch.Tensor): The length of each target sequence in the batch. + batch_idx (int): The index of the current batch. + dataloader_idx (int, optional): The index of the dataloader in case of multiple + validation dataloaders. Defaults to 0. + + Returns: + dict: A dictionary containing various validation metrics for this batch. + """ + return self.validation_step(batch, batch_idx, dataloader_idx) + def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): if not outputs: logging.warning(f"`outputs` is None; empty outputs for dataloader={dataloader_idx}") @@ -1009,6 +1011,13 @@ def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target target_lens (torch.Tensor): Lengths of target sequences. Shape: (batch_size,) """ + if preds.shape[1] < targets.shape[1]: + logging.info( + f"WARNING! preds has less frames than targets ({preds.shape[1]} < {targets.shape[1]}). " + "Truncating targets and clamping target_lens." 
+ ) + targets = targets[:, : preds.shape[1], :] + target_lens = target_lens.clamp(max=preds.shape[1]) targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) self._accuracy_test(preds, targets_pil, target_lens) diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 7d2dd65c6f85..2a7ffeba5d3c 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -578,9 +578,7 @@ def forward( ) if bypass_pre_encode: - self.update_max_seq_length( - seq_length=audio_signal.size(2) * self.subsampling_factor, device=audio_signal.device - ) + self.update_max_seq_length(seq_length=audio_signal.size(1), device=audio_signal.device) else: self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device) return self.forward_internal( diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py index 0368b83847b2..e68f8dc2aea1 100644 --- a/nemo/collections/asr/modules/sortformer_modules.py +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -41,6 +41,8 @@ class StreamingSortformerState: fifo_lengths (torch.Tensor): Lengths of the FIFO queue fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts spk_perm (torch.Tensor): Speaker permutation information for the speaker cache + mean_sil_emb (torch.Tensor): Mean silence embedding + n_sil_frames (torch.Tensor): Number of silence frames """ spkcache = None # Speaker cache to store embeddings from start @@ -50,6 +52,8 @@ class StreamingSortformerState: fifo_lengths = None fifo_preds = None spk_perm = None + mean_sil_emb = None + n_sil_frames = None class SortformerModules(NeuralModule, Exportable): @@ -74,9 +78,9 @@ def __init__( tf_d_model: int = 192, subsampling_factor: int = 8, spkcache_len: int = 188, - fifo_len: int = 0, - chunk_len: int = 376, - spkcache_refresh_rate: int = 1, + fifo_len: int = 188, + chunk_len: int = 12, + spkcache_update_period: int = 1, chunk_left_context: int = 1, chunk_right_context: int = 1, spkcache_sil_frames_per_spk: int = 3, @@ -112,7 +116,7 @@ def __init__( self.chunk_left_context = chunk_left_context self.chunk_right_context = chunk_right_context self.spkcache_sil_frames_per_spk = spkcache_sil_frames_per_spk - self.spkcache_refresh_rate = spkcache_refresh_rate + self.spkcache_update_period = spkcache_update_period self.causal_attn_rate = causal_attn_rate self.causal_attn_rc = causal_attn_rc self.scores_add_rnd = scores_add_rnd @@ -123,8 +127,49 @@ def __init__( self.strong_boost_rate = strong_boost_rate self.weak_boost_rate = weak_boost_rate self.min_pos_scores_rate = min_pos_scores_rate + self._check_streaming_parameters() + + def _check_streaming_parameters(self): + """ + Check if there are any illegal parameter combinations. + + Restrictions: + - All streaming parameters should be non-negative integers. + - Chunk length and speaker cache update period should be greater than 0. + - Speaker cache length should be greater than or equal to `(1 + spkcache_sil_frames_per_spk ) * n_spk`. 
+ - The effective range of self.spkcache_update_period is: chunk_len <= spkcache_update_period <= fifo_len + chunk_len + """ + param_constraints = { + 'spkcache_len': (1 + self.spkcache_sil_frames_per_spk) * self.n_spk, + 'fifo_len': 0, + 'chunk_len': 1, + 'spkcache_update_period': 1, + 'chunk_left_context': 0, + 'chunk_right_context': 0, + 'spkcache_sil_frames_per_spk': 0, + } + + for param, min_val in param_constraints.items(): + val = getattr(self, param) + if not isinstance(val, int): + raise TypeError(f"Parameter '{param}' must be an integer, but got {param}: {val}") + if val < min_val: + raise ValueError(f"Parameter '{param}' must be at least {min_val}, but got {val}.") + + if self.spkcache_update_period < self.chunk_len: + logging.warning( + f"spkcache_update_period ({self.spkcache_update_period}) is less than chunk_len ({self.chunk_len}). " + f"The effective update period will be {self.chunk_len}." + ) + if self.spkcache_update_period > self.fifo_len + self.chunk_len: + logging.warning( + f"spkcache_update_period ({self.spkcache_update_period}) is greater than " + f"fifo_len + chunk_len ({self.fifo_len + self.chunk_len}). " + f"The effective update period will be {self.fifo_len + self.chunk_len}." + ) - def length_to_mask(self, lengths, max_length: int): + @staticmethod + def length_to_mask(lengths, max_length: int): """ Convert length values to encoder mask input tensor @@ -211,8 +256,8 @@ def forward_speaker_sigmoids(self, hidden_out): preds = F.sigmoid(spk_preds) return preds + @staticmethod def concat_embs( - self, list_of_tensors=List[torch.Tensor], return_lengths: bool = False, dim: int = 1, @@ -237,6 +282,50 @@ def concat_embs( else: return embs + @staticmethod + def concat_and_pad(embs: List[torch.Tensor], lengths: List[torch.Tensor]): + """ + Concatenates lengths[i] first embeddings of embs[i], and pads the rest elements with zeros. + + Args: + embs: List of embeddings Tensors of (batch_size, n_frames, emb_dim) shape + lengths: List of lengths Tensors of (batch_size,) shape + + Returns: + output: concatenated embeddings Tensor of (batch_size, n_frames, emb_dim) shape + total_lengths: output lengths Tensor of (batch_size,) shape + """ + # Error handling for mismatched list lengths + if len(embs) != len(lengths): + raise ValueError( + f"Length lists must have the same length, but got len(embs) - {len(embs)} " + f"and len(lengths) - {len(lengths)}." + ) + # Handle empty lists + if len(embs) == 0 or len(lengths) == 0: + raise ValueError( + f"Cannot concatenate empty lists of embeddings or lengths: embs - {len(embs)}, lengths - {len(lengths)}" + ) + + device, dtype = embs[0].device, embs[0].dtype + batch_size, emb_dim = embs[0].shape[0], embs[0].shape[2] + + total_lengths = torch.sum(torch.stack(lengths), dim=0) + sig_length = total_lengths.max().item() + + output = torch.zeros(batch_size, sig_length, emb_dim, device=device, dtype=dtype) + start_indices = torch.zeros(batch_size, dtype=torch.int64, device=device) + + for emb, length in zip(embs, lengths): + end_indices = start_indices + length + for batch_idx in range(batch_size): + output[batch_idx, start_indices[batch_idx] : end_indices[batch_idx]] = emb[ + batch_idx, : length[batch_idx] + ] + start_indices = end_indices + + return output, total_lengths + def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None): """ Initializes StreamingSortformerState with empty tensors or zero-valued tensors. 
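
The concat_and_pad helper (moved here from sortformer_diar_models.py and made a @staticmethod so it can be TorchScript-compiled and reused) keeps only the first lengths[i] frames of each tensor per batch element and zero-pads the tail. A small usage sketch with toy shapes; the import path matches the `_target_` used in the new YAML config:

    import torch
    from nemo.collections.asr.modules.sortformer_modules import SortformerModules

    # Sample 0 has 3 valid spkcache frames and 1 valid FIFO frame; sample 1 has 2 and 4.
    spkcache = torch.randn(2, 5, 8)            # (batch, max_frames, emb_dim)
    fifo = torch.randn(2, 4, 8)
    spkcache_lens = torch.tensor([3, 2])
    fifo_lens = torch.tensor([1, 4])

    out, out_lens = SortformerModules.concat_and_pad([spkcache, fifo],
                                                     [spkcache_lens, fifo_lens])
    print(out.shape, out_lens)                 # torch.Size([2, 6, 8]) tensor([4, 6])
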
@@ -259,9 +348,12 @@ def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = Fals else: streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device) streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device) + streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device) + streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device) return streaming_state - def apply_mask_to_preds(self, spkcache_fifo_chunk_preds, spkcache_fifo_chunk_fc_encoder_lengths): + @staticmethod + def apply_mask_to_preds(spkcache_fifo_chunk_preds, spkcache_fifo_chunk_fc_encoder_lengths): """ Applies mask to speaker cache and FIFO queue to ensure that only valid frames are considered for predictions from the model. @@ -312,12 +404,8 @@ def streaming_update_async(self, streaming_state, chunk, chunk_lengths, preds, l chunk.shape[1] - lc - rc, ) - if self.fifo_len == 0: - max_pop_out_len = max_chunk_len - elif self.spkcache_refresh_rate == 0: - max_pop_out_len = self.fifo_len - else: - max_pop_out_len = min(max(self.spkcache_refresh_rate, max_chunk_len), self.fifo_len) + max_pop_out_len = max(self.spkcache_update_period, max_chunk_len) + max_pop_out_len = min(max_pop_out_len, max_chunk_len + max_fifo_len) streaming_state.fifo_preds = torch.zeros((batch_size, max_fifo_len, n_spk), device=preds.device) chunk_preds = torch.zeros((batch_size, max_chunk_len, n_spk), device=preds.device) @@ -354,47 +442,63 @@ def streaming_update_async(self, streaming_state, chunk, chunk_lengths, preds, l ] if fifo_len + chunk_len > max_fifo_len: # move pop_out_len first frames of FIFO queue to speaker cache - pop_out_len = min(max_pop_out_len, fifo_len + chunk_len) + pop_out_len = self.spkcache_update_period + pop_out_len = max(pop_out_len, max_chunk_len - max_fifo_len + fifo_len) + pop_out_len = min(pop_out_len, fifo_len + chunk_len) streaming_state.spkcache_lengths[batch_index] += pop_out_len - updated_spkcache[batch_index, spkcache_len : spkcache_len + pop_out_len, :] = updated_fifo[ - batch_index, :pop_out_len, : - ] + pop_out_embs = updated_fifo[batch_index, :pop_out_len, :] + pop_out_preds = updated_fifo_preds[batch_index, :pop_out_len, :] + ( + streaming_state.mean_sil_emb[batch_index : batch_index + 1], + streaming_state.n_sil_frames[batch_index : batch_index + 1], + ) = self._get_silence_profile( + streaming_state.mean_sil_emb[batch_index : batch_index + 1], + streaming_state.n_sil_frames[batch_index : batch_index + 1], + pop_out_embs.unsqueeze(0), + pop_out_preds.unsqueeze(0), + ) + updated_spkcache[batch_index, spkcache_len : spkcache_len + pop_out_len, :] = pop_out_embs if updated_spkcache_preds[batch_index, 0, 0] >= 0: # speaker cache already compressed at least once - updated_spkcache_preds[batch_index, spkcache_len : spkcache_len + pop_out_len, :] = ( - updated_fifo_preds[batch_index, :pop_out_len, :] - ) + updated_spkcache_preds[batch_index, spkcache_len : spkcache_len + pop_out_len, :] = pop_out_preds elif spkcache_len + pop_out_len > self.spkcache_len: # will compress speaker cache for the first time updated_spkcache_preds[batch_index, :spkcache_len, :] = preds[batch_index, :spkcache_len, :] - updated_spkcache_preds[batch_index, spkcache_len : spkcache_len + pop_out_len, :] = ( - updated_fifo_preds[batch_index, :pop_out_len, :] - ) + updated_spkcache_preds[batch_index, spkcache_len : spkcache_len + pop_out_len, :] = pop_out_preds streaming_state.fifo_lengths[batch_index] -= pop_out_len 
new_fifo_len = streaming_state.fifo_lengths[batch_index].item() updated_fifo[batch_index, :new_fifo_len, :] = updated_fifo[ batch_index, pop_out_len : pop_out_len + new_fifo_len, : ].clone() + updated_fifo_preds[batch_index, :new_fifo_len, :] = updated_fifo_preds[ + batch_index, pop_out_len : pop_out_len + new_fifo_len, : + ].clone() updated_fifo[batch_index, new_fifo_len:, :] = 0 + updated_fifo_preds[batch_index, new_fifo_len:, :] = 0 streaming_state.fifo = updated_fifo[:, :max_fifo_len, :] + streaming_state.fifo_preds = updated_fifo_preds[:, :max_fifo_len, :] # update speaker cache need_compress = streaming_state.spkcache_lengths > self.spkcache_len - streaming_state.spkcache = updated_spkcache[:, : self.spkcache_len :, :] - streaming_state.spkcache_preds = updated_spkcache_preds[:, : self.spkcache_len :, :] + streaming_state.spkcache = updated_spkcache[:, : self.spkcache_len, :] + streaming_state.spkcache_preds = updated_spkcache_preds[:, : self.spkcache_len, :] idx = torch.where(need_compress)[0] if len(idx) > 0: streaming_state.spkcache[idx], streaming_state.spkcache_preds[idx], _ = self._compress_spkcache( - emb_seq=updated_spkcache[idx], preds=updated_spkcache_preds[idx], permute_spk=False + emb_seq=updated_spkcache[idx], + preds=updated_spkcache_preds[idx], + mean_sil_emb=streaming_state.mean_sil_emb[idx], + permute_spk=False, ) streaming_state.spkcache_lengths[idx] = streaming_state.spkcache_lengths[idx].clamp(max=self.spkcache_len) if self.log: logging.info( - f"MC spkcache: {streaming_state.spkcache.shape}, " - f"chunk: {chunk.shape}, fifo: {streaming_state.fifo.shape}, " + f"spkcache: {streaming_state.spkcache.shape}, spkcache_lengths: {streaming_state.spkcache_lengths}, " + f"fifo: {streaming_state.fifo.shape}, fifo_lengths: {streaming_state.fifo_lengths}, " + f"chunk: {chunk.shape}, chunk_lengths: {chunk_lengths}, " f"chunk_preds: {chunk_preds.shape}" ) @@ -440,23 +544,24 @@ def streaming_update(self, streaming_state, chunk, preds, lc: int = 0, rc: int = chunk = chunk[:, lc : chunk_len + lc] chunk_preds = preds[:, spkcache_len + fifo_len + lc : spkcache_len + fifo_len + chunk_len + lc] - # pop_out_len is the number of frames we will pop out from FIFO to update spkcache - if self.fifo_len == 0: - pop_out_len = chunk_len - elif self.spkcache_refresh_rate == 0: - pop_out_len = self.fifo_len - else: - pop_out_len = min(max(self.spkcache_refresh_rate, chunk_len), self.fifo_len) - # append chunk to fifo streaming_state.fifo = torch.cat([streaming_state.fifo, chunk], dim=1) streaming_state.fifo_preds = torch.cat([streaming_state.fifo_preds, chunk_preds], dim=1) if fifo_len + chunk_len > self.fifo_len: # extract pop_out_len first frames from FIFO queue + pop_out_len = self.spkcache_update_period + pop_out_len = max(pop_out_len, chunk_len - self.fifo_len + fifo_len) pop_out_len = min(pop_out_len, fifo_len + chunk_len) + pop_out_embs = streaming_state.fifo[:, :pop_out_len] pop_out_preds = streaming_state.fifo_preds[:, :pop_out_len] + streaming_state.mean_sil_emb, streaming_state.n_sil_frames = self._get_silence_profile( + streaming_state.mean_sil_emb, + streaming_state.n_sil_frames, + pop_out_embs, + pop_out_preds, + ) streaming_state.fifo = streaming_state.fifo[:, pop_out_len:] streaming_state.fifo_preds = streaming_state.fifo_preds[:, pop_out_len:] @@ -471,15 +576,15 @@ def streaming_update(self, streaming_state, chunk, preds, lc: int = 0, rc: int = self._compress_spkcache( emb_seq=streaming_state.spkcache, preds=streaming_state.spkcache_preds, + 
mean_sil_emb=streaming_state.mean_sil_emb, permute_spk=self.training, ) ) if self.log: logging.info( - f"spkcache: {streaming_state.spkcache.shape}, " - f"chunk: {chunk.shape}, fifo: {streaming_state.fifo.shape}, " - f"chunk_preds: {chunk_preds.shape}" + f"spkcache: {streaming_state.spkcache.shape}, fifo: {streaming_state.fifo.shape}, " + f"chunk: {chunk.shape}, chunk_preds: {chunk_preds.shape}" ) return streaming_state, chunk_preds @@ -509,28 +614,38 @@ def _boost_topk_scores( scores[batch_indices, topk_indices, speaker_indices] -= scale_factor * math.log(offset) return scores - def _get_silence_profile(self, emb_seq, preds): + def _get_silence_profile(self, mean_sil_emb, n_sil_frames, emb_seq, preds): """ - Get mean silence embedding from emb_seq sequence. + Get updated mean silence embedding and number of silence frames from emb_seq sequence. Embeddings are considered as silence if sum of corresponding preds is lower than self.sil_threshold. Args: + mean_sil_emb (torch.Tensor): Previous mean silence embedding tensor + Shape: (batch_size, emb_dim) + n_sil_frames (torch.Tensor): Previous number of silence frames + Shape: (batch_size) emb_seq (torch.Tensor): Tensor containing sequence of embeddings Shape: (batch_size, n_frames, emb_dim) preds (torch.Tensor): Tensor containing speaker activity probabilities Shape: (batch_size, n_frames, n_spk) Returns: - mean_sil_emb (torch.Tensor): Mean silence embedding tensor + mean_sil_emb (torch.Tensor): Updated mean silence embedding tensor Shape: (batch_size, emb_dim) + n_sil_frames (torch.Tensor): Updated number of silence frames + Shape: (batch_size) """ is_sil = preds.sum(dim=2) < self.sil_threshold - is_sil = is_sil.unsqueeze(-1) - emb_seq_sil = torch.where(is_sil, emb_seq, torch.tensor(0.0)) # (batch_size, n_frames, emb_dim) - emb_seq_sil_sum = emb_seq_sil.sum(dim=1) # (batch_size, emb_dim) - sil_count = is_sil.sum(dim=1).clamp(min=1) # (batch_size) - mean_sil_emb = emb_seq_sil_sum / sil_count # (batch_size, emb_dim) - return mean_sil_emb + sil_count = is_sil.sum(dim=1) + has_new_sil = sil_count > 0 + if not has_new_sil.any(): + return mean_sil_emb, n_sil_frames + sil_emb_sum = torch.sum(emb_seq * is_sil.unsqueeze(-1), dim=1) + upd_n_sil_frames = n_sil_frames + sil_count + old_sil_emb_sum = mean_sil_emb * n_sil_frames.unsqueeze(1) + total_sil_sum = old_sil_emb_sum + sil_emb_sum + upd_mean_sil_emb = total_sil_sum / torch.clamp(upd_n_sil_frames.unsqueeze(1), min=1) + return upd_mean_sil_emb, upd_n_sil_frames def _get_log_pred_scores(self, preds): """ @@ -584,7 +699,7 @@ def _get_topk_indices(self, scores): topk_indices_sorted[is_disabled] = 0 # Set a placeholder index to make gather work return topk_indices_sorted, is_disabled - def _gather_spkcache_and_preds(self, emb_seq, preds, topk_indices, is_disabled): + def _gather_spkcache_and_preds(self, emb_seq, preds, topk_indices, is_disabled, mean_sil_emb): """ Gather embeddings from emb_seq and speaker activities from preds corresponding to topk_indices. For disabled frames, use mean silence embedding and zero probability instead. @@ -598,6 +713,8 @@ def _gather_spkcache_and_preds(self, emb_seq, preds, topk_indices, is_disabled): Shape: (batch_size, spkcache_len) is_disabled (torch.Tensor): Tensor containing binary mask for disabled frames Shape: (batch_size, spkcache_len) + mean_sil_emb (torch.Tensor): Tensor containing mean silence embedding + Shape: (batch_size, emb_dim) Returns: emb_seq_gathered (torch.Tensor): Tensor containing gathered embeddings. 
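
With the new mean_sil_emb / n_sil_frames state, _get_silence_profile now maintains a running mean of silence embeddings (frames whose summed speaker probabilities fall below sil_threshold) across the whole stream instead of recomputing it from the current buffer. The update reduces to the standard incremental-mean formula; a toy check, mirroring the code above:

    import torch

    # new_mean = (old_mean * n_old + sum(new_silence_embs)) / (n_old + n_new), per batch element.
    old_mean = torch.tensor([[1.0, 1.0]])                     # (batch=1, emb_dim=2)
    n_old = torch.tensor([2])                                  # silence frames accumulated so far
    new_sil = torch.tensor([[[4.0, 0.0], [0.0, 4.0]]])         # two newly popped-out silence frames

    n_new = torch.tensor([new_sil.shape[1]])
    new_mean = (old_mean * n_old.unsqueeze(1) + new_sil.sum(dim=1)) / (n_old + n_new).unsqueeze(1)
    print(new_mean, n_old + n_new)                             # tensor([[1.5000, 1.5000]]) tensor([4])
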
@@ -611,7 +728,6 @@ def _gather_spkcache_and_preds(self, emb_seq, preds, topk_indices, is_disabled): emb_dim, n_spk = emb_seq.shape[2], preds.shape[2] indices_expanded_emb = topk_indices.unsqueeze(-1).expand(-1, -1, emb_dim) emb_seq_gathered = torch.gather(emb_seq, 1, indices_expanded_emb) # (batch_size, spkcache_len, emb_dim) - mean_sil_emb = self._get_silence_profile(emb_seq, preds) # Compute mean silence embedding mean_sil_emb_expanded = mean_sil_emb.unsqueeze(1).expand(-1, self.spkcache_len, -1) emb_seq_gathered = torch.where(is_disabled.unsqueeze(-1), mean_sil_emb_expanded, emb_seq_gathered) @@ -700,7 +816,7 @@ def _permute_speakers(self, scores, max_perm_index): scores = torch.stack(scores_list).to(scores.device) return scores, spk_perm - def _compress_spkcache(self, emb_seq, preds, permute_spk: bool = False): + def _compress_spkcache(self, emb_seq, preds, mean_sil_emb, permute_spk: bool = False): """ Compress speaker cache for streaming inference. Keep spkcache_len most important frames out of input n_frames, based on preds. @@ -710,6 +826,8 @@ def _compress_spkcache(self, emb_seq, preds, permute_spk: bool = False): Shape: (batch_size, n_frames, emb_dim) preds (torch.Tensor): Tensor containing n_frames > spkcache_len speaker activity probabilities Shape: (batch_size, n_frames, n_spk) + mean_sil_emb (torch.Tensor): Tensor containing mean silence embedding + Shape: (batch_size, emb_dim) permute_spk (bool): If true, will generate a random permutation of existing speakers Returns: @@ -753,5 +871,7 @@ def _compress_spkcache(self, emb_seq, preds, permute_spk: bool = False): scores = torch.cat([scores, pad], dim=1) # (batch_size, n_frames + spkcache_sil_frames_per_spk, n_spk) topk_indices, is_disabled = self._get_topk_indices(scores) - spkcache, spkcache_preds = self._gather_spkcache_and_preds(emb_seq, preds, topk_indices, is_disabled) + spkcache, spkcache_preds = self._gather_spkcache_and_preds( + emb_seq, preds, topk_indices, is_disabled, mean_sil_emb + ) return spkcache, spkcache_preds, spk_perm
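
For reference, the pop-out logic that streaming_update() and streaming_update_async() now share can be distilled as below (a sketch mirroring the code in this diff, not library code, using the inference defaults from e2e_diarize_speech.py shown earlier); it also illustrates why _check_streaming_parameters() only warns, rather than errors, when spkcache_update_period falls outside [chunk_len, fifo_len + chunk_len]:

    # Number of frames moved from the FIFO queue into the speaker cache when the FIFO
    # overflows; fifo_len/chunk_len are the configured maxima, cur_fifo_len is the
    # current FIFO occupancy when the new chunk is appended.
    def effective_pop_out_len(spkcache_update_period, chunk_len, fifo_len, cur_fifo_len):
        pop_out_len = spkcache_update_period
        pop_out_len = max(pop_out_len, chunk_len - fifo_len + cur_fifo_len)  # keep FIFO within fifo_len
        pop_out_len = min(pop_out_len, cur_fifo_len + chunk_len)             # cannot pop more than is buffered
        return pop_out_len

    # Defaults from e2e_diarize_speech.py: fifo_len=188, chunk_len=6, spkcache_update_period=144.
    print(effective_pop_out_len(144, 6, 188, 188))   # 144: a full FIFO pops 144 frames at a time
    # A period larger than fifo_len + chunk_len is silently capped, hence the warning:
    print(effective_pop_out_len(400, 6, 188, 188))   # 194 (= fifo_len + chunk_len)
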