Skip to content

Commit ac0534f

Browse files
tango4j and nasretdinovr
authored and committed
Streaming Sortformer release PR01: uploading bugfixes, refactored variables and yaml file name changes (NVIDIA-NeMo#14416)
* Uploading bugfixes, refactored vars and yamlfile name changes Signed-off-by: taejinp <[email protected]> * Adding the missing offline pp yamls Signed-off-by: taejinp <[email protected]> --------- Signed-off-by: taejinp <[email protected]>
1 parent e9d0b04 commit ac0534f

File tree

7 files changed

+485
-99
lines changed

7 files changed

+485
-99
lines changed
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
# Sortformer Diarizer is an end-to-end speaker diarization model based solely on a Transformer-encoder type of architecture.
2+
# Model name convention for Sortformer Diarizer: streaming_sortformer_diarizer_<speaker count limit>-<version>.yaml
3+
# (Example) `streaming_sortformer_diarizer_4spk-v2.yaml`.
4+
# Sortformer Diarizer model checkpoint (.ckpt) and NeMo file (.nemo) contain the Fast Conformer Encoder model (NEST Encoder); the pre-trained NEST model is loaded along with the Transformer Encoder layers.
5+
# Example: a manifest line for training
6+
# {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"}
7+
name: "StreamingSortformerDiarizer"
8+
num_workers: 18
9+
batch_size: 4
10+
11+
model:
12+
sample_rate: 16000
13+
pil_weight: 0.5 # Weight for Permutation Invariant Loss (PIL) used in training the Sortformer diarizer model
14+
ats_weight: 0.5 # Weight for Arrival Time Sort (ATS) loss in training the Sortformer diarizer model
15+
max_num_of_spks: 4 # Maximum number of speakers per model; currently set to 4
16+
streaming_mode: True
17+
18+
model_defaults:
19+
fc_d_model: 512 # Hidden dimension size of the Fast-conformer Encoder
20+
tf_d_model: 192 # Hidden dimension size of the Transformer Encoder
21+
22+
train_ds:
23+
manifest_filepath: ???
24+
sample_rate: ${model.sample_rate}
25+
num_spks: ${model.max_num_of_spks}
26+
session_len_sec: 90 # Maximum session length in seconds
27+
soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity.
28+
soft_targets: False # If True, use continuous values as target values when calculating cross-entropy loss
29+
labels: null
30+
batch_size: ${batch_size}
31+
shuffle: True
32+
num_workers: ${num_workers}
33+
validation_mode: False
34+
# lhotse config
35+
use_lhotse: False
36+
use_bucketing: True
37+
num_buckets: 10
38+
bucket_duration_bins: [10, 20, 30, 40, 50, 60, 70, 80, 90]
39+
pin_memory: True
40+
min_duration: 10
41+
max_duration: 90
42+
batch_duration: 400
43+
quadratic_duration: 1200
44+
bucket_buffer_size: 20000
45+
shuffle_buffer_size: 10000
46+
window_stride: ${model.preprocessor.window_stride}
47+
subsampling_factor: ${model.encoder.subsampling_factor}
48+
49+
validation_ds:
50+
manifest_filepath: ???
51+
is_tarred: False
52+
tarred_audio_filepaths: null
53+
sample_rate: ${model.sample_rate}
54+
num_spks: ${model.max_num_of_spks}
55+
session_len_sec: 90 # Maximum session length in seconds
56+
soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity.
57+
soft_targets: False
58+
labels: null
59+
batch_size: ${batch_size}
60+
shuffle: False
61+
num_workers: ${num_workers}
62+
validation_mode: True
63+
# lhotse config
64+
use_lhotse: False
65+
use_bucketing: False
66+
drop_last: False
67+
pin_memory: True
68+
window_stride: ${model.preprocessor.window_stride}
69+
subsampling_factor: ${model.encoder.subsampling_factor}
70+
71+
test_ds:
72+
manifest_filepath: null
73+
is_tarred: False
74+
tarred_audio_filepaths: null
75+
sample_rate: 16000
76+
num_spks: ${model.max_num_of_spks}
77+
session_len_sec: 90 # Maximum session length in seconds
78+
soft_label_thres: 0.5
79+
soft_targets: False
80+
labels: null
81+
batch_size: ${batch_size}
82+
shuffle: False
83+
seq_eval_mode: True
84+
num_workers: ${num_workers}
85+
validation_mode: True
86+
# lhotse config
87+
use_lhotse: False
88+
use_bucketing: False
89+
drop_last: False
90+
pin_memory: True
91+
window_stride: ${model.preprocessor.window_stride}
92+
subsampling_factor: ${model.encoder.subsampling_factor}
93+
94+
preprocessor:
95+
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
96+
normalize: "NA"
97+
window_size: 0.025
98+
sample_rate: ${model.sample_rate}
99+
window_stride: 0.01
100+
window: "hann"
101+
features: 128
102+
n_fft: 512
103+
frame_splicing: 1
104+
dither: 0.00001
105+
106+
sortformer_modules:
107+
_target_: nemo.collections.asr.modules.sortformer_modules.SortformerModules
108+
num_spks: ${model.max_num_of_spks} # Maximum number of speakers the model can handle
109+
dropout_rate: 0.5 # Dropout rate
110+
fc_d_model: ${model.model_defaults.fc_d_model} # Hidden dimension size for Fast Conformer encoder
111+
tf_d_model: ${model.model_defaults.tf_d_model} # Hidden dimension size for Transformer encoder
112+
# Streaming mode parameters
113+
spkcache_len: 188 # Length of speaker cache buffer (total number of frames for all speakers)
114+
fifo_len: 0 # Length of FIFO buffer for streaming processing (0 = disabled)
115+
chunk_len: 188 # Number of frames processed in each streaming chunk
116+
spkcache_update_period: 1 # Speaker cache update period in frames
117+
chunk_left_context: 1 # Number of previous frames for each streaming chunk
118+
chunk_right_context: 1 # Number of future frames for each streaming chunk
119+
# Speaker cache update parameters
120+
spkcache_sil_frames_per_spk: 3 # Number of silence frames allocated per speaker in the speaker cache
121+
scores_add_rnd: 0 # Standard deviation of random noise added to scores in speaker cache update (training only)
122+
pred_score_threshold: 0.25 # Probability threshold for internal scores processing in speaker cache update
123+
max_index: 99999 # Maximum allowed index value for internal processing in speaker cache update
124+
scores_boost_latest: 0.05 # Gain for scores for recently added frames in speaker cache update
125+
sil_threshold: 0.2 # Threshold for determining silence frames to calculate average silence embedding
126+
strong_boost_rate: 0.75 # Rate determining number of frames per speaker that receive strong score boosting
127+
weak_boost_rate: 1.5 # Rate determining number of frames per speaker that receive weak score boosting
128+
min_pos_scores_rate: 0.5 # Rate threshold for dropping overlapping frames when enough non-overlapping exist
129+
# Self-attention parameters (training only)
130+
causal_attn_rate: 0.5 # Proportion of batches that use self-attention with limited right context
131+
causal_attn_rc: 7 # Right context size for self-attention with limited right context
132+
133+
encoder:
134+
_target_: nemo.collections.asr.modules.ConformerEncoder
135+
feat_in: ${model.preprocessor.features}
136+
feat_out: -1
137+
n_layers: 17
138+
d_model: ${model.model_defaults.fc_d_model}
139+
140+
# Sub-sampling parameters
141+
subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding
142+
subsampling_factor: 8 # must be power of 2 for striding and vggnet
143+
subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model
144+
causal_downsampling: false
145+
# Feed forward module's params
146+
ff_expansion_factor: 4
147+
# Multi-headed Attention Module's params
148+
self_attention_model: rel_pos # rel_pos or abs_pos
149+
n_heads: 8 # may need to be lower for smaller d_models
150+
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
151+
att_context_size: [-1, -1] # -1 means unlimited context
152+
att_context_style: regular # regular or chunked_limited
153+
xscaling: true # scales up the input embeddings by sqrt(d_model)
154+
untie_biases: true # unties the biases of the TransformerXL layers
155+
pos_emb_max_len: 5000
156+
# Convolution module's params
157+
conv_kernel_size: 9
158+
conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
159+
conv_context_size: null
160+
# Regularization
161+
dropout: 0.1 # The dropout used in most of the Conformer Modules
162+
dropout_pre_encoder: 0.1 # The dropout used before the encoder
163+
dropout_emb: 0.0 # The dropout used for embeddings
164+
dropout_att: 0.1 # The dropout for multi-headed attention modules
165+
# Set to non-zero to enable stochastic depth
166+
stochastic_depth_drop_prob: 0.0
167+
stochastic_depth_mode: linear # linear or uniform
168+
stochastic_depth_start_layer: 1
169+
170+
transformer_encoder:
171+
_target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder
172+
num_layers: 18
173+
hidden_size: ${model.model_defaults.tf_d_model} # Needs to be multiple of num_attention_heads
174+
inner_size: 768
175+
num_attention_heads: 8
176+
attn_score_dropout: 0.5
177+
attn_layer_dropout: 0.5
178+
ffn_dropout: 0.5
179+
hidden_act: relu
180+
pre_ln: False
181+
pre_ln_final_layer_norm: True
182+
183+
loss:
184+
_target_: nemo.collections.asr.losses.bce_loss.BCELoss
185+
weight: null # Weight for binary cross-entropy loss. Either `null` or list type input. (e.g. [0.5,0.5])
186+
reduction: mean
187+
188+
lr: 0.0001
189+
optim:
190+
name: adamw
191+
lr: ${model.lr}
192+
# optimizer arguments
193+
betas: [0.9, 0.98]
194+
weight_decay: 1e-3
195+
196+
sched:
197+
name: InverseSquareRootAnnealing
198+
warmup_steps: 500
199+
warmup_ratio: null
200+
min_lr: 1e-06
201+
202+
trainer:
203+
devices: 1 # number of gpus (devices)
204+
accelerator: gpu
205+
max_epochs: 800
206+
max_steps: -1 # computed at runtime if not set
207+
num_nodes: 1
208+
strategy: ddp_find_unused_parameters_true # Could be "ddp"
209+
accumulate_grad_batches: 1
210+
deterministic: True
211+
enable_checkpointing: False
212+
logger: False
213+
log_every_n_steps: 1 # Interval of logging.
214+
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
215+
216+
exp_manager:
217+
use_datetime_version: False
218+
exp_dir: null
219+
name: ${name}
220+
resume_if_exists: True
221+
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
222+
resume_ignore_no_checkpoint: True
223+
create_tensorboard_logger: True
224+
create_checkpoint_callback: True
225+
create_wandb_logger: False
226+
checkpoint_callback_params:
227+
monitor: "val_f1_acc"
228+
mode: "max"
229+
save_top_k: 9
230+
every_n_epochs: 1
231+
wandb_logger_kwargs:
232+
resume: True
233+
name: null
234+
project: null
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Postprocessing parameters for timestamp outputs from speaker diarization models.
2+
# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper:
3+
# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
4+
# These parameters were optimized on the CallHome dataset from NIST SRE 2000 Disc8, specifically on part1 (callhome1) as specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v2/run.sh
5+
parameters:
6+
onset: 0.641 # Onset threshold for detecting the beginning of a speech segment
7+
offset: 0.561 # Offset threshold for detecting the end of a speech segment
8+
pad_onset: 0.229 # Adding durations before each speech segment
9+
pad_offset: 0.079 # Adding durations after each speech segment
10+
min_duration_on: 0.511 # Threshold for small non-speech deletion
11+
min_duration_off: 0.296 # Threshold for short speech segment deletion
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Postprocessing parameters for timestamp outputs from speaker diarization models.
2+
# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper:
3+
# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
4+
# These parameters were optimized on the development split of DIHARD3 dataset (See https://arxiv.org/pdf/2012.01477).
5+
parameters:
6+
onset: 0.56 # Onset threshold for detecting the beginning of a speech segment
7+
offset: 1.0 # Offset threshold for detecting the end of a speech segment
8+
pad_onset: 0.063 # Adding durations before each speech segment
9+
pad_offset: 0.002 # Adding durations after each speech segment
10+
min_duration_on: 0.007 # Threshold for small non-speech deletion
11+
min_duration_off: 0.151 # Threshold for short speech segment deletion

examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,15 @@ class DiarizationConfig:
9292
log: bool = False # If True, log will be printed
9393

9494
use_lhotse: bool = True
95-
batch_duration: int = 33000
95+
batch_duration: int = 100000
9696

9797
# Eval Settings: (0.25, False) should be default setting for sortformer eval.
9898
collar: float = 0.25 # Collar in seconds for DER calculation
9999
ignore_overlap: bool = False # If True, DER will be calculated only for non-overlapping segments
100100

101101
# Streaming diarization configs
102102
spkcache_len: int = 188
103-
spkcache_refresh_rate: int = 144
103+
spkcache_update_period: int = 144
104104
fifo_len: int = 188
105105
chunk_len: int = 6
106106
chunk_left_context: int = 1
@@ -386,7 +386,10 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]:
386386
diar_model.sortformer_modules.chunk_right_context = cfg.chunk_right_context
387387
diar_model.sortformer_modules.fifo_len = cfg.fifo_len
388388
diar_model.sortformer_modules.log = cfg.log
389-
diar_model.sortformer_modules.spkcache_refresh_rate = cfg.spkcache_refresh_rate
389+
diar_model.sortformer_modules.spkcache_update_period = cfg.spkcache_update_period
390+
391+
# Check if the streaming parameters are valid
392+
diar_model.sortformer_modules._check_streaming_parameters()
390393

391394
postprocessing_cfg = load_postprocessing_from_yaml(cfg.postprocessing_yaml)
392395
tensor_path, model_id, tensor_filename = get_tensor_path(cfg)

0 commit comments

Comments
 (0)