@@ -0,0 +1,234 @@
# Sortformer Diarizer is an end-to-end speaker diarization model based solely on a Transformer-encoder architecture.
# Model name convention for Sortformer Diarizer: streaming_sortformer_diarizer_<speaker count limit>-<version>.yaml
# Example: `streaming_sortformer_diarizer_4spk-v2.yaml`.
# The Sortformer Diarizer checkpoint (.ckpt) and NeMo file (.nemo) contain a Fast Conformer encoder (NEST encoder); the pre-trained NEST model is loaded along with the Transformer encoder layers.
# Example of a manifest line for training:
# {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"}
name: "StreamingSortformerDiarizer"
num_workers: 18
batch_size: 4

model:
  sample_rate: 16000
  pil_weight: 0.5 # Weight for Permutation Invariant Loss (PIL) used in training the Sortformer diarizer model
  ats_weight: 0.5 # Weight for Arrival Time Sort (ATS) loss in training the Sortformer diarizer model
  max_num_of_spks: 4 # Maximum number of speakers per model; currently set to 4
  streaming_mode: True

  model_defaults:
    fc_d_model: 512 # Hidden dimension size of the Fast Conformer encoder
    tf_d_model: 192 # Hidden dimension size of the Transformer encoder

  train_ds:
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    num_spks: ${model.max_num_of_spks}
    session_len_sec: 90 # Maximum session length in seconds
    soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity.
    soft_targets: False # If True, use continuous values as target values when calculating cross-entropy loss
    labels: null
    batch_size: ${batch_size}
    shuffle: True
    num_workers: ${num_workers}
    validation_mode: False
    # lhotse config
    use_lhotse: False
    use_bucketing: True
    num_buckets: 10
    bucket_duration_bins: [10, 20, 30, 40, 50, 60, 70, 80, 90]
    pin_memory: True
    min_duration: 10
    max_duration: 90
    batch_duration: 400
    quadratic_duration: 1200
    bucket_buffer_size: 20000
    shuffle_buffer_size: 10000
    window_stride: ${model.preprocessor.window_stride}
    subsampling_factor: ${model.encoder.subsampling_factor}

  validation_ds:
    manifest_filepath: ???
    is_tarred: False
    tarred_audio_filepaths: null
    sample_rate: ${model.sample_rate}
    num_spks: ${model.max_num_of_spks}
    session_len_sec: 90 # Maximum session length in seconds
    soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity.
    soft_targets: False
    labels: null
    batch_size: ${batch_size}
    shuffle: False
    num_workers: ${num_workers}
    validation_mode: True
    # lhotse config
    use_lhotse: False
    use_bucketing: False
    drop_last: False
    pin_memory: True
    window_stride: ${model.preprocessor.window_stride}
    subsampling_factor: ${model.encoder.subsampling_factor}

  test_ds:
    manifest_filepath: null
    is_tarred: False
    tarred_audio_filepaths: null
    sample_rate: 16000
    num_spks: ${model.max_num_of_spks}
    session_len_sec: 90 # Maximum session length in seconds
    soft_label_thres: 0.5
    soft_targets: False
    labels: null
    batch_size: ${batch_size}
    shuffle: False
    seq_eval_mode: True
    num_workers: ${num_workers}
    validation_mode: True
    # lhotse config
    use_lhotse: False
    use_bucketing: False
    drop_last: False
    pin_memory: True
    window_stride: ${model.preprocessor.window_stride}
    subsampling_factor: ${model.encoder.subsampling_factor}

  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    normalize: "NA"
    window_size: 0.025
    sample_rate: ${model.sample_rate}
    window_stride: 0.01
    window: "hann"
    features: 128
    n_fft: 512
    frame_splicing: 1
    dither: 0.00001

  sortformer_modules:
    _target_: nemo.collections.asr.modules.sortformer_modules.SortformerModules
    num_spks: ${model.max_num_of_spks} # Maximum number of speakers the model can handle
    dropout_rate: 0.5 # Dropout rate
    fc_d_model: ${model.model_defaults.fc_d_model} # Hidden dimension size of the Fast Conformer encoder
    tf_d_model: ${model.model_defaults.tf_d_model} # Hidden dimension size of the Transformer encoder
    # Streaming mode parameters
    spkcache_len: 188 # Length of the speaker cache buffer (total number of frames for all speakers)
    fifo_len: 0 # Length of the FIFO buffer for streaming processing (0 = disabled)
    chunk_len: 188 # Number of frames processed in each streaming chunk
    spkcache_update_period: 1 # Speaker cache update period in frames
    chunk_left_context: 1 # Number of previous frames for each streaming chunk
    chunk_right_context: 1 # Number of future frames for each streaming chunk
    # Speaker cache update parameters
    spkcache_sil_frames_per_spk: 3 # Number of silence frames allocated per speaker in the speaker cache
    scores_add_rnd: 0 # Standard deviation of random noise added to scores in the speaker cache update (training only)
    pred_score_threshold: 0.25 # Probability threshold for internal score processing in the speaker cache update
    max_index: 99999 # Maximum allowed index value for internal processing in the speaker cache update
    scores_boost_latest: 0.05 # Score gain for recently added frames in the speaker cache update
    sil_threshold: 0.2 # Threshold for determining silence frames used to compute the average silence embedding
    strong_boost_rate: 0.75 # Rate determining the number of frames per speaker that receive strong score boosting
    weak_boost_rate: 1.5 # Rate determining the number of frames per speaker that receive weak score boosting
    min_pos_scores_rate: 0.5 # Rate threshold for dropping overlapping frames when enough non-overlapping frames exist
    # Self-attention parameters (training only)
    causal_attn_rate: 0.5 # Proportion of batches that use self-attention with limited right context
    causal_attn_rc: 7 # Right context size for self-attention with limited right context
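    # Editorial note (worked arithmetic, not part of the original config): with
    # subsampling_factor: 8 (encoder section) and window_stride: 0.01 s
    # (preprocessor section), each encoder output frame covers 8 * 0.01 = 0.08 s
    # of audio, so chunk_len: 188 and spkcache_len: 188 each correspond to
    # roughly 15 seconds (188 * 0.08 = 15.04 s).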

  encoder:
    _target_: nemo.collections.asr.modules.ConformerEncoder
    feat_in: ${model.preprocessor.features}
    feat_out: -1
    n_layers: 17
    d_model: ${model.model_defaults.fc_d_model}

    # Sub-sampling parameters
    subsampling: dw_striding # vggnet, striding, stacking, stacking_norm, or dw_striding
    subsampling_factor: 8 # must be a power of 2 for striding and vggnet
    subsampling_conv_channels: 256 # set to -1 to make it equal to d_model
    causal_downsampling: false
    # Feed-forward module parameters
    ff_expansion_factor: 4
    # Multi-headed attention module parameters
    self_attention_model: rel_pos # rel_pos or abs_pos
    n_heads: 8 # may need to be lower for smaller d_models
    # [left, right] specifies the number of steps seen to the left and right of each step in self-attention
    att_context_size: [-1, -1] # -1 means unlimited context
    att_context_style: regular # regular or chunked_limited
    xscaling: true # scales up the input embeddings by sqrt(d_model)
    untie_biases: true # unties the biases of the TransformerXL layers
    pos_emb_max_len: 5000
    # Convolution module parameters
    conv_kernel_size: 9
    conv_norm_type: 'batch_norm' # batch_norm, layer_norm, or groupnormN (N specifies the number of groups)
    conv_context_size: null
    # Regularization
    dropout: 0.1 # dropout used in most of the Conformer modules
    dropout_pre_encoder: 0.1 # dropout used before the encoder
    dropout_emb: 0.0 # dropout used for embeddings
    dropout_att: 0.1 # dropout for the multi-headed attention modules
    # Set to non-zero to enable stochastic depth
    stochastic_depth_drop_prob: 0.0
    stochastic_depth_mode: linear # linear or uniform
    stochastic_depth_start_layer: 1

  transformer_encoder:
    _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder
    num_layers: 18
    hidden_size: ${model.model_defaults.tf_d_model} # must be a multiple of num_attention_heads
    inner_size: 768
    num_attention_heads: 8
    attn_score_dropout: 0.5
    attn_layer_dropout: 0.5
    ffn_dropout: 0.5
    hidden_act: relu
    pre_ln: False
    pre_ln_final_layer_norm: True

  loss:
    _target_: nemo.collections.asr.losses.bce_loss.BCELoss
    weight: null # Weight for binary cross-entropy loss; either `null` or a list (e.g., [0.5, 0.5])
    reduction: mean

  lr: 0.0001
  optim:
    name: adamw
    lr: ${model.lr}
    # optimizer arguments
    betas: [0.9, 0.98]
    weight_decay: 1e-3

    sched:
      name: InverseSquareRootAnnealing
      warmup_steps: 500
      warmup_ratio: null
      min_lr: 1e-06

trainer:
  devices: 1 # number of GPUs (devices)
  accelerator: gpu
  max_epochs: 800
  max_steps: -1 # computed at runtime if not set
  num_nodes: 1
  strategy: ddp_find_unused_parameters_true # could be "ddp"
  accumulate_grad_batches: 1
  deterministic: True
  enable_checkpointing: False
  logger: False
  log_every_n_steps: 1 # logging interval
  val_check_interval: 1.0 # set to 0.25 to check 4 times per epoch, or to an int for a number of iterations

exp_manager:
  use_datetime_version: False
  exp_dir: null
  name: ${name}
  resume_if_exists: True
  resume_from_checkpoint: null # Path to a checkpoint file from which to continue training; restores the full state, including epoch, step, and LR schedulers
  resume_ignore_no_checkpoint: True
  create_tensorboard_logger: True
  create_checkpoint_callback: True
  create_wandb_logger: False
  checkpoint_callback_params:
    monitor: "val_f1_acc"
    mode: "max"
    save_top_k: 9
    every_n_epochs: 1
  wandb_logger_kwargs:
    resume: True
    name: null
    project: null
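
To make the config above concrete, here is a minimal training sketch of how such a file is typically consumed. The class name SortformerEncLabelModel and all paths are assumptions drawn from NeMo's speaker diarization examples, not part of this diff.

# Minimal training sketch (assumptions: NeMo's SortformerEncLabelModel class and a
# standard PyTorch Lightning entry point; newer NeMo versions import
# `lightning.pytorch` instead of `pytorch_lightning`).
import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.utils.exp_manager import exp_manager

cfg = OmegaConf.load("streaming_sortformer_diarizer_4spk-v2.yaml")
cfg.model.train_ds.manifest_filepath = "/path/to/train_manifest.json"    # hypothetical path
cfg.model.validation_ds.manifest_filepath = "/path/to/val_manifest.json" # hypothetical path

trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))

# Instantiate the model from the `model` sub-config and train.
diar_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer)
trainer.fit(diar_model)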
@@ -0,0 +1,11 @@
# Postprocessing parameters for timestamp outputs from speaker diarization models.
# This postprocessing scheme is inspired by the procedure described in:
# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
# These parameters were optimized on the CallHome dataset from NIST SRE 2000 Disc8, specifically part1 (callhome1) as defined in the Kaldi x-vector recipe v2: https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v2/run.sh
parameters:
  onset: 0.641 # Onset threshold for detecting the beginning of a speech segment
  offset: 0.561 # Offset threshold for detecting the end of a speech segment
  pad_onset: 0.229 # Duration added before each speech segment
  pad_offset: 0.079 # Duration added after each speech segment
  min_duration_on: 0.511 # Threshold for short non-speech deletion
  min_duration_off: 0.296 # Threshold for short speech segment deletion
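
For intuition, the following is a rough sketch of how parameters like these are applied to per-frame speech probabilities. It illustrates the onset/offset hysteresis, boundary padding, and duration filtering named above; it is not NeMo's actual implementation, and the gap/segment mapping of the two duration filters follows the inline comments.

# Illustrative sketch only (not NeMo's implementation).
def postprocess(probs, frame_dur, onset, offset, pad_onset, pad_offset,
                min_duration_on, min_duration_off):
    # 1) Hysteresis thresholding: open a segment when the probability rises
    #    above `onset`, close it when the probability falls below `offset`.
    segments, start, active = [], 0.0, False
    for i, p in enumerate(probs):
        t = i * frame_dur
        if not active and p >= onset:
            start, active = t, True
        elif active and p < offset:
            segments.append([start, t])
            active = False
    if active:
        segments.append([start, len(probs) * frame_dur])

    # 2) Pad segment boundaries: pad_onset before, pad_offset after.
    segments = [[max(0.0, s - pad_onset), e + pad_offset] for s, e in segments]

    # 3) Delete short non-speech gaps by merging (min_duration_on, per the
    #    comment above), then delete short speech segments (min_duration_off).
    merged = []
    for s, e in segments:
        if merged and s - merged[-1][1] < min_duration_on:
            merged[-1][1] = e
        else:
            merged.append([s, e])
    return [(s, e) for s, e in merged if e - s >= min_duration_off]

With the CallHome-tuned values above, for example, a 0.2 s gap between two segments would be bridged (0.2 < 0.511) and an isolated 0.25 s segment would be dropped (0.25 < 0.296).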
@@ -0,0 +1,11 @@
# Postprocessing parameters for timestamp outputs from speaker diarization models.
# This postprocessing scheme is inspired by the procedure described in:
# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
# These parameters were optimized on the development split of the DIHARD3 dataset (see https://arxiv.org/pdf/2012.01477).
parameters:
  onset: 0.56 # Onset threshold for detecting the beginning of a speech segment
  offset: 1.0 # Offset threshold for detecting the end of a speech segment
  pad_onset: 0.063 # Duration added before each speech segment
  pad_offset: 0.002 # Duration added after each speech segment
  min_duration_on: 0.007 # Threshold for short non-speech deletion
  min_duration_off: 0.151 # Threshold for short speech segment deletion
@@ -92,15 +92,15 @@ class DiarizationConfig:
     log: bool = False  # If True, log will be printed

     use_lhotse: bool = True
-    batch_duration: int = 33000
+    batch_duration: int = 100000

     # Eval Settings: (0.25, False) should be the default setting for Sortformer eval.
     collar: float = 0.25  # Collar in seconds for DER calculation
     ignore_overlap: bool = False  # If True, DER will be calculated only for non-overlapping segments

     # Streaming diarization configs
     spkcache_len: int = 188
-    spkcache_refresh_rate: int = 144
+    spkcache_update_period: int = 144
     fifo_len: int = 188
     chunk_len: int = 6
     chunk_left_context: int = 1
@@ -386,7 +386,10 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]:
     diar_model.sortformer_modules.chunk_right_context = cfg.chunk_right_context
     diar_model.sortformer_modules.fifo_len = cfg.fifo_len
     diar_model.sortformer_modules.log = cfg.log
-    diar_model.sortformer_modules.spkcache_refresh_rate = cfg.spkcache_refresh_rate
+    diar_model.sortformer_modules.spkcache_update_period = cfg.spkcache_update_period
+
+    # Check if the streaming parameters are valid
+    diar_model.sortformer_modules._check_streaming_parameters()

     postprocessing_cfg = load_postprocessing_from_yaml(cfg.postprocessing_yaml)
     tensor_path, model_id, tensor_filename = get_tensor_path(cfg)
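
For downstream users of this script, the rename from spkcache_refresh_rate to spkcache_update_period means any existing override of the old field must be updated. A hedged sketch of building the eval config with the new field name (assuming the remaining DiarizationConfig fields keep their defaults; only the fields visible in the hunks above are used):

# Sketch only: DiarizationConfig is the dataclass shown in the first hunk.
cfg = DiarizationConfig(
    spkcache_len=188,
    spkcache_update_period=144,  # formerly `spkcache_refresh_rate`
    fifo_len=188,
    chunk_len=6,
    chunk_left_context=1,
    chunk_right_context=1,
)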