Skip to content

Commit 0a5093a

Browse files
tango4jguyueh1
authored andcommitted
Streaming Sortformer release PR04: Adding functional tests for streaming sortformer (NVIDIA-NeMo#14435)
* Adding functional tests for streaming sortformer Signed-off-by: taejinp <tango4j@gmail.com> * Fixing the unchanged part from offline sortformer Signed-off-by: taejinp <tango4j@gmail.com> * Fixing lightning imports for seeding Signed-off-by: taejinp <tango4j@gmail.com> * Apply isort and black reformatting Signed-off-by: tango4j <tango4j@users.noreply.github.com> --------- Signed-off-by: taejinp <tango4j@gmail.com> Signed-off-by: tango4j <tango4j@users.noreply.github.com> Co-authored-by: tango4j <tango4j@users.noreply.github.com> Signed-off-by: Guyue Huang <guyueh@nvidia.com>
1 parent 4ec25c4 commit 0a5093a

4 files changed

Lines changed: 94 additions & 1 deletion

File tree

examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
# limitations under the License.
1414

1515
import lightning.pytorch as pl
16+
from lightning.pytorch import seed_everything
1617
from omegaconf import OmegaConf
17-
from pytorch_lightning import seed_everything
1818

1919
from nemo.collections.asr.models import SortformerEncLabelModel
2020
from nemo.core.config import hydra_runner
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import lightning.pytorch as pl
16+
from lightning.pytorch import seed_everything
17+
from omegaconf import OmegaConf
18+
19+
from nemo.collections.asr.models import SortformerEncLabelModel
20+
from nemo.core.config import hydra_runner
21+
from nemo.utils import logging
22+
from nemo.utils.exp_manager import exp_manager
23+
24+
"""
25+
Example training session (single node training)
26+
27+
python ./streaming_sortformer_diar_train.py --config-path='../conf/neural_diarizer' \
28+
--config-name='streaming_sortformer_diarizer_4spk-v2.yaml' \
29+
trainer.devices=1 \
30+
model.train_ds.manifest_filepath="<train_manifest_path>" \
31+
model.validation_ds.manifest_filepath="<dev_manifest_path>" \
32+
exp_manager.name='sample_train' \
33+
exp_manager.exp_dir='./streaming_sortformer_diar_train'
34+
"""
35+
36+
seed_everything(42)
37+
38+
39+
@hydra_runner(config_path="../conf/neural_diarizer", config_name="streaming_sortformer_diarizer_4spk-v2.yaml")
40+
def main(cfg):
41+
"""Main function for training the sortformer diarizer model."""
42+
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
43+
trainer = pl.Trainer(**cfg.trainer)
44+
exp_manager(trainer, cfg.get("exp_manager", None))
45+
sortformer_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer)
46+
sortformer_model.maybe_init_from_pretrained_checkpoint(cfg)
47+
trainer.fit(sortformer_model)
48+
49+
if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
50+
if sortformer_model.prepare_test(trainer):
51+
trainer.test(sortformer_model)
52+
53+
54+
if __name__ == '__main__':
55+
main()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py \
15+
model_path=/home/TestData/an4_diarizer/diar_streaming_sortformer_4spk-v2-tiny.nemo \
16+
dataset_manifest=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \
17+
batch_size=1
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/speaker_tasks/diarization/neural_diarizer/streaming_sortformer_diar_train.py \
15+
trainer.devices="[0]" \
16+
batch_size=3 \
17+
model.train_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_train/eesd_train_tiny.json \
18+
model.test_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \
19+
model.validation_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \
20+
exp_manager.exp_dir=/tmp/speaker_diarization_results \
21+
+trainer.fast_dev_run=True

0 commit comments

Comments
 (0)